"""
Custom OWSM model with entity-weighted loss for Caribbean Voices challenge.
This implements loss re-weighting for proper nouns without external data.
"""
import torch
import torch.nn as nn
from transformers import AutoModelForSpeechSeq2Seq, PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from typing import Set, Optional, Dict, Any


class OWSMWithEntityLoss(PreTrainedModel):
"""
Wrapper around OWSM model that implements weighted cross-entropy loss
to up-weight errors on entity tokens.
This model wraps the base model using composition rather than inheritance
to avoid issues with the AutoModel factory pattern.
"""
def __init__(self, config, base_model, tokenizer, high_value_tokens: Set[str], entity_weight: float = 3.0):
"""
Args:
config: Model configuration
base_model: The instantiated base model (SpeechEncoderDecoderModel)
tokenizer: Tokenizer for converting entity words to token IDs
high_value_tokens: Set of entity words (lowercase) to up-weight
entity_weight: Multiplier for entity token errors (default: 3.0)
"""
super().__init__(config)
self.model = base_model
self.tokenizer = tokenizer
self.entity_weight = entity_weight
# Store mapping from entity word to all its token IDs
self.entity_word_to_token_ids: Dict[str, Set[int]] = {}
all_entity_token_ids = set()
print(f"Building entity token ID set from {len(high_value_tokens)} entities...")
for word in high_value_tokens:
tokens = tokenizer.tokenize(word)
if tokens:
# Get ALL token IDs for this entity word
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_id_set = set(token_ids)
self.entity_word_to_token_ids[word] = token_id_set
all_entity_token_ids.update(token_id_set)
print(f" → Mapped to {len(all_entity_token_ids)} unique token IDs")
if self.entity_word_to_token_ids:
avg_tokens = sum(len(ids) for ids in self.entity_word_to_token_ids.values()) / len(self.entity_word_to_token_ids)
print(f" → Average tokens per entity: {avg_tokens:.2f}")

        # Pre-compute vocab_weights tensor for O(1) lookup during training
        vocab_size = config.vocab_size if hasattr(config, 'vocab_size') else len(tokenizer)
        self.register_buffer('vocab_weights', torch.ones(vocab_size, dtype=torch.float32))

        # Set entity token weights
        for token_id in all_entity_token_ids:
            if 0 <= token_id < vocab_size:
                self.vocab_weights[token_id] = self.entity_weight

        # Store for debugging
        self.entity_token_ids = all_entity_token_ids
        self.high_value_tokens = high_value_tokens

    def get_encoder(self):
        """Delegate to sub-model's encoder."""
        return self.model.get_encoder()

    def get_decoder(self):
        """Delegate to sub-model's decoder."""
        return self.model.get_decoder()

    def forward(self, input_features=None, attention_mask=None, decoder_input_ids=None, labels=None, **kwargs):
        """
        Forward pass that computes weighted loss if labels are provided.
        Delegates to underlying model.
        """
        # Filter out arguments that the base model doesn't accept
        # num_items_in_batch is passed by newer transformers versions but not accepted by model
        model_kwargs = {k: v for k, v in kwargs.items() if k != 'num_items_in_batch'}

        outputs = self.model(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
            return_dict=True,
            **model_kwargs
        )

        # If we are not training or have no labels, return standard outputs
        if labels is None:
            return outputs

        # Custom Loss Computation
        logits = outputs.logits  # [B, T, V]

        # Flatten: standard CrossEntropyLoss expects [N, C] logits and [N] labels,
        # where N is batch_size * sequence_length
        flat_logits = logits.view(-1, logits.size(-1))
        flat_labels = labels.view(-1)

        # Create per-token weights from the pre-computed buffer: O(1) lookup
        # labels can be -100 (ignore index), which must be handled before indexing
        # into vocab_weights. Build a mask of valid positions (label != -100).
        valid_mask = (flat_labels != -100)

        # Replace -100 with index 0 purely to keep the lookup in bounds;
        # these positions are masked out of the loss below.
        safe_labels = flat_labels.clone()
        safe_labels[~valid_mask] = 0

        # Gather weights for each target token
        weights = self.vocab_weights[safe_labels]

        # Compute unreduced (per-token) loss
        loss_fct = nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(flat_logits, flat_labels)

        # Apply weights
        weighted_loss = loss * weights

        # CrossEntropyLoss with reduction='none' already yields 0 for -100 labels,
        # but explicit masking is safer when combined with custom weighting.
        weighted_loss = weighted_loss[valid_mask]

        if weighted_loss.numel() == 0:
            final_loss = torch.tensor(0.0, device=logits.device, requires_grad=True)
        else:
            final_loss = weighted_loss.mean()
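
        # Worked example with toy numbers (illustrative only, not real values):
        # if vocab_weights = [1, 1, 3] (token 2 is an entity token),
        # flat_labels = [2, 0, -100], and the per-token CE losses are [1.2, 0.5, 0.0],
        # then weights = [3, 1, 1], weighted_loss = [3.6, 0.5, 0.0], and after
        # dropping the -100 position the final loss is (3.6 + 0.5) / 2 = 2.05.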

        return Seq2SeqLMOutput(
            loss=final_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def generate(self, *args, **kwargs):
        """Delegate generation to the underlying model."""
        return self.model.generate(*args, **kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Delegate to underlying model."""
        return self.model.prepare_inputs_for_generation(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str,
                        tokenizer, high_value_tokens: Set[str],
                        entity_weight: float = 3.0, **kwargs):
        """Load pretrained OWSM model and wrap with entity-weighted loss."""
        # Load the base model using the Auto class
        base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )

        # Initialize wrapper
        model = cls(
            config=base_model.config,
            base_model=base_model,
            tokenizer=tokenizer,
            high_value_tokens=high_value_tokens,
            entity_weight=entity_weight
        )

        # Copy important attributes from the base model to ensure full compatibility
        # with transformers components like Seq2SeqTrainer, data collators, etc.

        # 1. generation_config - required for Seq2SeqTrainer evaluation:
        #    Seq2SeqTrainer accesses model.generation_config._from_model_config in prediction_step
        if hasattr(base_model, 'generation_config') and base_model.generation_config is not None:
            # Copy generation_config from the base model (preferred)
            model.generation_config = base_model.generation_config
        else:
            # Fallback: derive generation_config from the model config. This handles
            # base models without a generation_config and guarantees it is never None,
            # preventing AttributeError during evaluation.
            from transformers import GenerationConfig
            try:
                model.generation_config = GenerationConfig.from_model_config(model.config)
            except Exception:
                # Last resort: a minimal default config
                model.generation_config = GenerationConfig()

        # 1b. Prefer the modern task/language flags over the deprecated forced_decoder_ids.
        #     Once task/language are set, forced_decoder_ids is ignored by transformers.
        if hasattr(model.generation_config, 'task'):
            if model.generation_config.task is None:
                # Default task for Whisper-style models (transcribe, not translate)
                model.generation_config.task = "transcribe"
            # Clear forced_decoder_ids when task is set to avoid deprecation warnings
            if hasattr(model.generation_config, 'forced_decoder_ids') and model.generation_config.forced_decoder_ids is not None:
                model.generation_config.forced_decoder_ids = None

        # 1c. Ensure pad_token_id is set in generation_config to avoid attention mask
        #     warnings; this is important when pad_token_id == eos_token_id.
        if hasattr(tokenizer, 'pad_token_id') and tokenizer.pad_token_id is not None:
            if hasattr(model.generation_config, 'pad_token_id'):
                model.generation_config.pad_token_id = tokenizer.pad_token_id

        # If the base model has a language set, preserve it; otherwise leave it unset (auto-detect).
        # Note: for Caribbean Voices we want transcription, not translation to English,
        # so we don't force language='en' - let the model auto-detect or use the config value.

        # 2. main_input_name - important for data collators and input handling,
        #    e.g. "input_features" for Whisper, "input_values" for Wav2Vec2
        if hasattr(base_model, 'main_input_name'):
            model.main_input_name = base_model.main_input_name

        # 3. Model-specific attributes that might be set on the instance.
        #    forced_decoder_ids is deprecated in favor of the task/language flags above;
        #    it is copied here only for backward compatibility.
        for attr_name in ['forced_decoder_ids', 'suppress_tokens']:
            if hasattr(base_model, attr_name):
                attr_value = getattr(base_model, attr_name)
                if attr_value is not None:
                    setattr(model, attr_name, attr_value)

        return model

    def save_pretrained(self, save_directory, **kwargs):
        """
        Save the underlying model to the directory.
        This ensures that the saved model is a standard OWSM model
        that can be loaded with AutoModelForSpeechSeq2Seq for inference.
        """
        print(f"Saving underlying model to {save_directory}...")
        self.model.save_pretrained(save_directory, **kwargs)
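

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of wrapping a pretrained checkpoint with the entity-weighted
# loss. The checkpoint name and entity list below are placeholders, not the
# challenge's actual configuration; substitute the OWSM/Whisper-style checkpoint
# and entity lexicon used for training.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "openai/whisper-small"  # placeholder: any AutoModelForSpeechSeq2Seq checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Example entity words (lowercase), stand-ins for the real lexicon
    entities = {"kingston", "bridgetown", "tobago"}

    model = OWSMWithEntityLoss.from_pretrained(
        checkpoint,
        tokenizer=tokenizer,
        high_value_tokens=entities,
        entity_weight=3.0,
    )

    # The wrapper behaves like a standard seq2seq model: pass it to Seq2SeqTrainer
    # for fine-tuning, or call model.generate(...) for inference.
    print(f"Wrapped {type(model.model).__name__} with "
          f"{len(model.entity_token_ids)} up-weighted token IDs")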