# caribbean-voices-hackathon/extract_entities.py
"""
Extract high-value Caribbean entities from training transcripts.
This builds a gazetteer purely from the competition dataset (no external data).
Supports both single-word and multi-word entity extraction.
"""
import pandas as pd
import re
from typing import Dict, List, Sequence, Set, Tuple

def extract_ngrams(words: List[str], n: int) -> List[Tuple[str, ...]]:
"""Extract n-grams from a list of words."""
return [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
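# For illustration: extract_ngrams(["a", "b", "c"], 2) -> [("a", "b"), ("b", "c")]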

# Lowercase connector words allowed inside proper names, e.g. "Port of Spain"
LOWERCASE_CONNECTORS = {"of", "the", "and"}

def is_phrase_capitalized(phrase_words: Sequence[str]) -> bool:
    """Check if a phrase is title-cased (lowercase connectors allowed after the first word)."""
    if not phrase_words or not phrase_words[0][:1].isupper():
        return False
    return all(w[:1].isupper() for w in phrase_words[1:]
               if w.lower() not in LOWERCASE_CONNECTORS)
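# For illustration: is_phrase_capitalized(["Port", "of", "Spain"]) -> True
# (lowercase "of" is allowed), while ["the", "Caribbean"] -> False
# (a leading lowercase word still fails).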
def extract_entities_from_transcripts(train_df: pd.DataFrame,
min_frequency: int = 50,
min_frequency_multiword: int = 20,
capitalization_threshold: float = 0.7,
verbose: bool = True) -> Set[str]:
"""
Extract high-value entities from training transcripts based on:
1. Frequency (appears > min_frequency times)
2. Capitalization pattern (capitalized/ALLCAPS most of the time)
3. Multi-word phrase detection (bigrams, trigrams)
4. Proximity to known Caribbean keywords (optional filter)
Args:
train_df: DataFrame with 'transcription' column (lowercase)
min_frequency: Minimum occurrences for single-word entities
min_frequency_multiword: Minimum occurrences for multi-word entities
capitalization_threshold: Minimum ratio of capitalized occurrences (0-1)
verbose: Print progress and statistics
"""
    # Known Caribbean keywords for context filtering (reserved for the
    # keyword-proximity criterion in the docstring; not applied by the
    # current filters)
caribbean_keywords = {
"caribbean", "bbc", "report", "london", "port", "prime", "minister",
"trinidad", "tobago", "jamaica", "guyana", "haiti", "barbados",
"antigua", "dominica", "grenada", "montserrat", "lucia", "kitts",
"nevis", "suriname", "caricom", "west", "indies"
}
# Exclusion list: common words that are frequently capitalized but not entities
EXCLUDED_WORDS = {
# Single letters
"i", "u",
# Titles/honorifics
"mr", "mrs", "ms", "dr", "sir", "madam",
# Days of week
"monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday",
# Months
"january", "february", "march", "april", "may", "june",
"july", "august", "september", "october", "november", "december",
        # Common nouns that often start sentences (excluded as single words, but allowed inside phrases)
"minister", "assembly", "council", "department", "secretariat",
"secretary", "parliament", "congress", "labour", "republic", "states",
"attorney", "association",
# Adjectives (nationality/descriptive)
"british", "american", "cuban", "haitian", "guyanese", "jamaican",
"trinidadian", "dominican", "african", "european", "indian", "dutch",
"french", "eastern", "latin", "south", "west", "north",
}
# Exclusion patterns for multi-word phrases
EXCLUDED_PHRASE_PATTERNS = {
# Generic time references
"last week", "this week", "next week", "last year", "this year", "next year",
"last month", "this month", "next month",
# Generic titles without names
"prime minister", "foreign minister", "finance minister",
# Articles and common phrases
"the report", "the government", "the country",
}
# Track word and phrase occurrences with capitalization info
word_stats: Dict[str, Dict[str, int]] = {}
phrase_stats: Dict[str, Dict[str, int]] = {}
if verbose:
print("\n[1/3] Analyzing single words and multi-word phrases...")
# Support both 'Transcription' (CSV) and 'transcription' (HF dataset)
transcription_col = 'transcription' if 'transcription' in train_df.columns else 'Transcription'
for transcription in train_df[transcription_col]:
if pd.isna(transcription):
continue
        # Tokenize: keep alphabetic tokens only (digits and punctuation are dropped)
words = re.findall(r'\b[A-Za-z]+\b', str(transcription))
# === SINGLE WORD EXTRACTION ===
for word in words:
word_lower = word.lower()
if word_lower not in word_stats:
word_stats[word_lower] = {
'total': 0,
'capitalized': 0,
'allcaps': 0,
                    'near_caribbean': 0,  # reserved for keyword proximity; never updated
}
word_stats[word_lower]['total'] += 1
# Check capitalization
if word.isupper() and len(word) > 1:
word_stats[word_lower]['allcaps'] += 1
elif word[0].isupper():
word_stats[word_lower]['capitalized'] += 1
        # === MULTI-WORD EXTRACTION (bigrams and trigrams) ===
        # The bigram and trigram passes share one loop via extract_ngrams
        for n in (2, 3):
            for phrase in extract_ngrams(words, n):
                phrase_lower = ' '.join(w.lower() for w in phrase)
                if phrase_lower not in phrase_stats:
                    phrase_stats[phrase_lower] = {
                        'total': 0,
                        'capitalized': 0,
                    }
                phrase_stats[phrase_lower]['total'] += 1
                # Check if phrase is properly capitalized
                if is_phrase_capitalized(phrase):
                    phrase_stats[phrase_lower]['capitalized'] += 1
if verbose:
print(f" - Analyzed {len(word_stats):,} unique words")
print(f" - Analyzed {len(phrase_stats):,} unique phrases")
# === FILTER SINGLE WORDS ===
if verbose:
print(f"\n[2/3] Filtering single-word entities...")
single_word_entities = set()
for word_lower, stats in word_stats.items():
# Minimum length filter
if len(word_lower) < 2:
continue
# Exclusion list filter
if word_lower in EXCLUDED_WORDS:
continue
# Frequency filter
if stats['total'] < min_frequency:
continue
# Capitalization filter
capitalized_ratio = (stats['capitalized'] + stats['allcaps']) / stats['total']
if capitalized_ratio < capitalization_threshold:
continue
single_word_entities.add(word_lower)
if verbose:
print(f" - Found {len(single_word_entities)} single-word entities")
# === FILTER MULTI-WORD PHRASES ===
if verbose:
print(f"\n[3/3] Filtering multi-word entities...")
multiword_entities = set()
for phrase_lower, stats in phrase_stats.items():
# Exclusion pattern filter
if phrase_lower in EXCLUDED_PHRASE_PATTERNS:
continue
# Frequency filter (lower threshold for multi-word)
if stats['total'] < min_frequency_multiword:
continue
# Capitalization filter
capitalized_ratio = stats['capitalized'] / stats['total']
if capitalized_ratio < capitalization_threshold:
continue
        # Rescue phrases whose component words would not qualify alone.
        # For example, "puerto rico" is kept even if "puerto" by itself
        # never passes the single-word filters.
        words_in_phrase = phrase_lower.split()
        # Allow if at least one word is a known entity, or the phrase matches
        # a common place/title pattern. Patterns are checked per token so that
        # e.g. "united states" is not caught by a stray "st" substring.
        has_known_entity = any(w in single_word_entities for w in words_in_phrase)
        is_common_pattern = (
            any(w in ('st', 'saint', 'puerto') for w in words_in_phrase)
            or 'port of' in phrase_lower
            or 'prime minister' in phrase_lower
        )
        if has_known_entity or is_common_pattern:
            multiword_entities.add(phrase_lower)
if verbose:
print(f" - Found {len(multiword_entities)} multi-word entities")
    # === COMBINE ===
    # Union of both sets. Overlapping short entities are kept alongside the
    # longer phrases; downstream matching should prefer the longest match.
    all_entities = single_word_entities | multiword_entities
if verbose:
print(f"\n✓ Total entities: {len(all_entities)} ({len(single_word_entities)} single + {len(multiword_entities)} multi-word)")
# Show top entities by type
if single_word_entities:
entity_freqs = [(w, word_stats[w]['total']) for w in single_word_entities]
entity_freqs.sort(key=lambda x: x[1], reverse=True)
print("\nTop 20 single-word entities:")
for word, freq in entity_freqs[:20]:
stats = word_stats[word]
cap_ratio = (stats['capitalized'] + stats['allcaps']) / stats['total']
print(f" {word:25s} | freq={freq:5d} | cap_ratio={cap_ratio:.2f}")
if multiword_entities:
phrase_freqs = [(p, phrase_stats[p]['total']) for p in multiword_entities]
phrase_freqs.sort(key=lambda x: x[1], reverse=True)
print("\nTop 20 multi-word entities:")
for phrase, freq in phrase_freqs[:20]:
stats = phrase_stats[phrase]
cap_ratio = stats['capitalized'] / stats['total']
print(f" {phrase:25s} | freq={freq:5d} | cap_ratio={cap_ratio:.2f}")
return all_entities
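

# A minimal usage sketch. Assumptions (hypothetical, adjust to the actual
# competition layout): the training data lives at "Train.csv" and the
# gazetteer is written to "entities.txt" for downstream scripts.
if __name__ == "__main__":
    train_df = pd.read_csv("Train.csv")  # hypothetical path
    entities = extract_entities_from_transcripts(train_df)
    # Persist one entity per line so other stages can reload the gazetteer
    with open("entities.txt", "w", encoding="utf-8") as f:
        for entity in sorted(entities):
            f.write(entity + "\n")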