""" Custom OWSM model with entity-weighted loss for Caribbean Voices challenge. This implements loss re-weighting for proper nouns without external data. """ import torch import torch.nn as nn from transformers import AutoModelForSpeechSeq2Seq, PreTrainedModel from transformers.modeling_outputs import Seq2SeqLMOutput from typing import Set, Optional, Dict, Any class OWSMWithEntityLoss(PreTrainedModel): """ Wrapper around OWSM model that implements weighted cross-entropy loss to up-weight errors on entity tokens. This model wraps the base model using composition rather than inheritance to avoid issues with the AutoModel factory pattern. """ def __init__(self, config, base_model, tokenizer, high_value_tokens: Set[str], entity_weight: float = 3.0): """ Args: config: Model configuration base_model: The instantiated base model (SpeechEncoderDecoderModel) tokenizer: Tokenizer for converting entity words to token IDs high_value_tokens: Set of entity words (lowercase) to up-weight entity_weight: Multiplier for entity token errors (default: 3.0) """ super().__init__(config) self.model = base_model self.tokenizer = tokenizer self.entity_weight = entity_weight # Store mapping from entity word to all its token IDs self.entity_word_to_token_ids: Dict[str, Set[int]] = {} all_entity_token_ids = set() print(f"Building entity token ID set from {len(high_value_tokens)} entities...") for word in high_value_tokens: tokens = tokenizer.tokenize(word) if tokens: # Get ALL token IDs for this entity word token_ids = tokenizer.convert_tokens_to_ids(tokens) token_id_set = set(token_ids) self.entity_word_to_token_ids[word] = token_id_set all_entity_token_ids.update(token_id_set) print(f" → Mapped to {len(all_entity_token_ids)} unique token IDs") if self.entity_word_to_token_ids: avg_tokens = sum(len(ids) for ids in self.entity_word_to_token_ids.values()) / len(self.entity_word_to_token_ids) print(f" → Average tokens per entity: {avg_tokens:.2f}") # Pre-compute vocab_weights tensor for O(1) lookup during training vocab_size = config.vocab_size if hasattr(config, 'vocab_size') else len(tokenizer) self.register_buffer('vocab_weights', torch.ones(vocab_size, dtype=torch.float32)) # Set entity token weights for token_id in all_entity_token_ids: if 0 <= token_id < vocab_size: self.vocab_weights[token_id] = self.entity_weight # Store for debugging self.entity_token_ids = all_entity_token_ids self.high_value_tokens = high_value_tokens def get_encoder(self): """Delegate to sub-model's encoder.""" return self.model.get_encoder() def get_decoder(self): """Delegate to sub-model's decoder.""" return self.model.get_decoder() def forward(self, input_features=None, attention_mask=None, decoder_input_ids=None, labels=None, **kwargs): """ Forward pass that computes weighted loss if labels are provided. Delegates to underlying model. 
""" # Filter out arguments that the base model doesn't accept # num_items_in_batch is passed by newer transformers versions but not accepted by model model_kwargs = {k: v for k, v in kwargs.items() if k != 'num_items_in_batch'} outputs = self.model( input_features=input_features, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, labels=labels, return_dict=True, **model_kwargs ) # If we are not training or have no labels, return standard outputs if labels is None: return outputs # Custom Loss Computation logits = outputs.logits # [B, T, V] # Flatten # Standard CrossEntropyLoss expects [N, C] logits and [N] labels # where N is batch_size * sequence_length flat_logits = logits.view(-1, logits.size(-1)) flat_labels = labels.view(-1) # Create per-token weights # Use pre-computed weights: O(1) lookup # labels can be -100 (ignore), we need to handle that for lookup # Create a mask for valid labels (not -100) valid_mask = (flat_labels != -100) # Use padding token ID (usually 0 or 1) for lookup where label is -100 # This avoids index out of bounds. We'll mask the loss anyway. safe_labels = flat_labels.clone() safe_labels[~valid_mask] = 0 # Get weights weights = self.vocab_weights[safe_labels] # Compute unreduced loss loss_fct = nn.CrossEntropyLoss(reduction="none") loss = loss_fct(flat_logits, flat_labels) # Apply weights weighted_loss = loss * weights # Apply masking (CrossEntropyLoss usually handles -100 by ignoring, # but since we used reduction='none', we have to double check) # The loss for -100 labels should be 0 from CrossEntropyLoss if used correctly, # but explicit masking is safer with custom weighting. weighted_loss = weighted_loss[valid_mask] if weighted_loss.numel() == 0: final_loss = torch.tensor(0.0, device=logits.device, requires_grad=True) else: final_loss = weighted_loss.mean() return Seq2SeqLMOutput( loss=final_loss, logits=logits, past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, cross_attentions=outputs.cross_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, ) def generate(self, *args, **kwargs): """Delegate generation to the underlying model.""" return self.model.generate(*args, **kwargs) def prepare_inputs_for_generation(self, *args, **kwargs): """Delegate to underlying model.""" return self.model.prepare_inputs_for_generation(*args, **kwargs) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, tokenizer, high_value_tokens: Set[str], entity_weight: float = 3.0, **kwargs): """Load pretrained OWSM model and wrap with entity-weighted loss.""" # Load the base model using the Auto class base_model = AutoModelForSpeechSeq2Seq.from_pretrained( pretrained_model_name_or_path, **kwargs ) # Initialize wrapper model = cls( config=base_model.config, base_model=base_model, tokenizer=tokenizer, high_value_tokens=high_value_tokens, entity_weight=entity_weight ) # Copy important attributes from base model to ensure full compatibility # with transformers components like Seq2SeqTrainer, data collators, etc. # 1. 
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, tokenizer,
                        high_value_tokens: Set[str], entity_weight: float = 3.0,
                        **kwargs):
        """Load a pretrained OWSM model and wrap it with entity-weighted loss."""
        # Load the base model using the Auto class
        base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )

        # Initialize wrapper
        model = cls(
            config=base_model.config,
            base_model=base_model,
            tokenizer=tokenizer,
            high_value_tokens=high_value_tokens,
            entity_weight=entity_weight
        )

        # Copy important attributes from the base model to ensure full compatibility
        # with transformers components like Seq2SeqTrainer, data collators, etc.

        # 1. generation_config - Required for Seq2SeqTrainer evaluation.
        #    Seq2SeqTrainer accesses model.generation_config._from_model_config in prediction_step.
        if hasattr(base_model, 'generation_config') and base_model.generation_config is not None:
            # Copy generation_config from the base model (preferred)
            model.generation_config = base_model.generation_config
        else:
            # Fallback: create a generation_config from the model config.
            # This handles cases where the base model doesn't have generation_config set.
            from transformers import GenerationConfig
            try:
                model.generation_config = GenerationConfig.from_model_config(model.config)
            except Exception:
                # If GenerationConfig.from_model_config fails, create a minimal config.
                # This ensures generation_config is never None, preventing AttributeError.
                model.generation_config = GenerationConfig()

        # 1b. Ensure generation_config uses the modern task/language flags instead of the
        #     deprecated forced_decoder_ids. For Whisper models, prefer task="transcribe"
        #     and language settings; once task/language are set, forced_decoder_ids is
        #     ignored (per the transformers deprecation).
        if hasattr(model.generation_config, 'task'):
            if model.generation_config.task is None:
                # Set the default task for Whisper models (transcribe, not translate)
                model.generation_config.task = "transcribe"
            # If task is set, forced_decoder_ids will be ignored, so clear it to avoid warnings
            if (hasattr(model.generation_config, 'forced_decoder_ids')
                    and model.generation_config.forced_decoder_ids is not None):
                model.generation_config.forced_decoder_ids = None

        # 1c. Ensure pad_token_id is set in generation_config to avoid attention mask warnings.
        #     This is important when pad_token_id == eos_token_id.
        if hasattr(tokenizer, 'pad_token_id') and tokenizer.pad_token_id is not None:
            if hasattr(model.generation_config, 'pad_token_id'):
                model.generation_config.pad_token_id = tokenizer.pad_token_id

        # If the base model has a language set, preserve it; otherwise leave it unset (auto-detect).
        # Note: For Caribbean Voices we want transcription, not translation to English,
        # so we don't force language='en' - let the model auto-detect or use what's in the config.

        # 2. main_input_name - Important for data collators and input handling,
        #    e.g. "input_features" for Whisper, "input_values" for Wav2Vec2
        if hasattr(base_model, 'main_input_name'):
            model.main_input_name = base_model.main_input_name

        # 3. Model-specific config attributes that might be set on the instance.
        #    Note: forced_decoder_ids is deprecated in favor of task/language flags in
        #    generation_config. We still copy it for backward compatibility, but the
        #    modern approach is preferred.
        for attr_name in ['forced_decoder_ids', 'suppress_tokens']:
            if hasattr(base_model, attr_name):
                attr_value = getattr(base_model, attr_name)
                if attr_value is not None:
                    setattr(model, attr_name, attr_value)

        return model

    def save_pretrained(self, save_directory, **kwargs):
        """
        Save the underlying model to the directory.

        This ensures that the saved model is a standard OWSM model that can be
        loaded with AutoModelForSpeechSeq2Seq for inference.
        """
        print(f"Saving underlying model to {save_directory}...")
        self.model.save_pretrained(save_directory, **kwargs)
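

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The checkpoint name and the entity
# list below are placeholders, not project defaults; any Whisper-style
# checkpoint loadable with AutoModelForSpeechSeq2Seq should work the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "openai/whisper-small"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Placeholder entity set; in practice this comes from the challenge data.
    entities = {"tobago", "chaguanas", "basseterre"}

    model = OWSMWithEntityLoss.from_pretrained(
        checkpoint,
        tokenizer=tokenizer,
        high_value_tokens=entities,
        entity_weight=3.0,
    )

    # Training step (shapes only): `input_features` are log-mel features from the
    # feature extractor, `labels` are tokenized transcripts padded with -100.
    #
    #     outputs = model(input_features=batch["input_features"], labels=batch["labels"])
    #     outputs.loss.backward()
    #
    # Saving writes the *underlying* model, so the checkpoint can later be
    # reloaded with AutoModelForSpeechSeq2Seq for plain inference.
    #
    #     model.save_pretrained("./owsm-entity-weighted")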