"""
Custom OWSM model with entity-weighted loss for Caribbean Voices challenge.
This implements loss re-weighting for proper nouns without external data.
"""
import torch
import torch.nn as nn
from transformers import AutoModelForSpeechSeq2Seq, PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from typing import Dict, Set
class OWSMWithEntityLoss(PreTrainedModel):
"""
Wrapper around OWSM model that implements weighted cross-entropy loss
to up-weight errors on entity tokens.
This model wraps the base model using composition rather than inheritance
to avoid issues with the AutoModel factory pattern.
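
    Example (illustrative sketch only; the checkpoint name and entity words below
    are placeholders, not values used by this project):

        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("openai/whisper-small")
        model = OWSMWithEntityLoss.from_pretrained(
            "openai/whisper-small",
            tokenizer=tokenizer,
            high_value_tokens={"tobago", "chaguanas"},
            entity_weight=3.0,
        )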
"""
def __init__(self, config, base_model, tokenizer, high_value_tokens: Set[str], entity_weight: float = 3.0):
"""
Args:
config: Model configuration
            base_model: The instantiated base model returned by AutoModelForSpeechSeq2Seq (e.g., a Whisper-style speech seq2seq model)
tokenizer: Tokenizer for converting entity words to token IDs
high_value_tokens: Set of entity words (lowercase) to up-weight
entity_weight: Multiplier for entity token errors (default: 3.0)
"""
super().__init__(config)
self.model = base_model
self.tokenizer = tokenizer
self.entity_weight = entity_weight
# Store mapping from entity word to all its token IDs
self.entity_word_to_token_ids: Dict[str, Set[int]] = {}
all_entity_token_ids = set()
print(f"Building entity token ID set from {len(high_value_tokens)} entities...")
for word in high_value_tokens:
tokens = tokenizer.tokenize(word)
if tokens:
# Get ALL token IDs for this entity word
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_id_set = set(token_ids)
self.entity_word_to_token_ids[word] = token_id_set
all_entity_token_ids.update(token_id_set)
print(f" → Mapped to {len(all_entity_token_ids)} unique token IDs")
if self.entity_word_to_token_ids:
avg_tokens = sum(len(ids) for ids in self.entity_word_to_token_ids.values()) / len(self.entity_word_to_token_ids)
print(f" → Average tokens per entity: {avg_tokens:.2f}")
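        # Illustrative example (hypothetical tokenization): an entity like "chaguanas"
        # might split into sub-tokens such as ["ch", "agu", "anas"]; each of those
        # token IDs is recorded here and up-weighted in vocab_weights below.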
# Pre-compute vocab_weights tensor for O(1) lookup during training
vocab_size = config.vocab_size if hasattr(config, 'vocab_size') else len(tokenizer)
self.register_buffer('vocab_weights', torch.ones(vocab_size, dtype=torch.float32))
# Set entity token weights
for token_id in all_entity_token_ids:
if 0 <= token_id < vocab_size:
self.vocab_weights[token_id] = self.entity_weight
# Store for debugging
self.entity_token_ids = all_entity_token_ids
self.high_value_tokens = high_value_tokens
def get_encoder(self):
"""Delegate to sub-model's encoder."""
return self.model.get_encoder()
def get_decoder(self):
"""Delegate to sub-model's decoder."""
return self.model.get_decoder()
def forward(self, input_features=None, attention_mask=None, decoder_input_ids=None, labels=None, **kwargs):
"""
Forward pass that computes weighted loss if labels are provided.
Delegates to underlying model.
"""
        # Filter out arguments that the base model's forward() doesn't accept.
        # Newer transformers versions pass num_items_in_batch via the Trainer,
        # but the underlying model does not take it.
model_kwargs = {k: v for k, v in kwargs.items() if k != 'num_items_in_batch'}
outputs = self.model(
input_features=input_features,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
labels=labels,
return_dict=True,
**model_kwargs
)
# If we are not training or have no labels, return standard outputs
if labels is None:
return outputs
# Custom Loss Computation
logits = outputs.logits # [B, T, V]
# Flatten
# Standard CrossEntropyLoss expects [N, C] logits and [N] labels
# where N is batch_size * sequence_length
flat_logits = logits.view(-1, logits.size(-1))
flat_labels = labels.view(-1)
        # Per-token weights via the pre-computed vocab_weights buffer (O(1) lookup).
        # Labels may be -100 (the ignore index); those positions are masked out of the
        # loss below, and we substitute a safe placeholder index (0) for them here so
        # the weight lookup never goes out of bounds.
        valid_mask = (flat_labels != -100)
        safe_labels = flat_labels.clone()
        safe_labels[~valid_mask] = 0
# Get weights
weights = self.vocab_weights[safe_labels]
# Compute unreduced loss
loss_fct = nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(flat_logits, flat_labels)
# Apply weights
weighted_loss = loss * weights
        # With reduction='none', CrossEntropyLoss already emits 0 for ignore_index
        # (-100) positions, but we still select only the valid positions so the mean
        # below is taken over real tokens rather than diluted by padded zeros.
        weighted_loss = weighted_loss[valid_mask]
        if weighted_loss.numel() == 0:
            # Degenerate case: no valid labels in the batch. Return a zero loss that
            # stays connected to the graph so backward() (and DDP) behave normally.
            final_loss = flat_logits.sum() * 0.0
        else:
            final_loss = weighted_loss.mean()
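        # In effect: final_loss = mean over valid positions t of w[y_t] * CE(logits_t, y_t),
        # where w[y_t] is entity_weight for entity token IDs and 1.0 for everything else.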
return Seq2SeqLMOutput(
loss=final_loss,
logits=logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
def generate(self, *args, **kwargs):
"""Delegate generation to the underlying model."""
return self.model.generate(*args, **kwargs)
def prepare_inputs_for_generation(self, *args, **kwargs):
"""Delegate to underlying model."""
return self.model.prepare_inputs_for_generation(*args, **kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str,
tokenizer, high_value_tokens: Set[str],
entity_weight: float = 3.0, **kwargs):
"""Load pretrained OWSM model and wrap with entity-weighted loss."""
# Load the base model using the Auto class
base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
pretrained_model_name_or_path, **kwargs
)
# Initialize wrapper
model = cls(
config=base_model.config,
base_model=base_model,
tokenizer=tokenizer,
high_value_tokens=high_value_tokens,
entity_weight=entity_weight
)
# Copy important attributes from base model to ensure full compatibility
# with transformers components like Seq2SeqTrainer, data collators, etc.
# 1. generation_config - Required for Seq2SeqTrainer evaluation
# Seq2SeqTrainer accesses model.generation_config._from_model_config in prediction_step
if hasattr(base_model, 'generation_config') and base_model.generation_config is not None:
# Copy generation_config from base model (preferred method)
model.generation_config = base_model.generation_config
        else:
            # Fallback: build a generation_config from the model config. This handles
            # base models that don't have a generation_config of their own.
            from transformers import GenerationConfig
            try:
                model.generation_config = GenerationConfig.from_model_config(model.config)
            except Exception:
                # If from_model_config fails, fall back to a minimal default so that
                # generation_config is never None (Seq2SeqTrainer accesses it in
                # prediction_step, which would otherwise raise AttributeError).
                model.generation_config = GenerationConfig()
        # 1b. Prefer the modern task/language flags over the deprecated
        # forced_decoder_ids. For Whisper-style models, default the task to
        # "transcribe" (not "translate") when none is set.
        if hasattr(model.generation_config, 'task'):
            if model.generation_config.task is None:
                model.generation_config.task = "transcribe"
            # With task set, forced_decoder_ids is ignored anyway; clear it to avoid
            # deprecation warnings.
            if hasattr(model.generation_config, 'forced_decoder_ids') and model.generation_config.forced_decoder_ids is not None:
                model.generation_config.forced_decoder_ids = None
# 1c. Ensure pad_token_id is set in generation_config to avoid attention mask warnings
# This is important when pad_token_id == eos_token_id
if hasattr(tokenizer, 'pad_token_id') and tokenizer.pad_token_id is not None:
if hasattr(model.generation_config, 'pad_token_id'):
model.generation_config.pad_token_id = tokenizer.pad_token_id
        # Language: keep whatever the base model / config specifies, or leave it unset
        # so the model can auto-detect. We deliberately don't force a language here;
        # the task flag above already selects transcription over translation, which is
        # what the Caribbean Voices challenge needs.
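        # Note: callers can also override these at generation time, e.g.
        # model.generate(input_features, task="transcribe") when the underlying model
        # is Whisper-style (sketch; supported kwargs depend on the base model class).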
# 2. main_input_name - Important for data collators and input handling
# e.g., "input_features" for Whisper, "input_values" for Wav2Vec2
if hasattr(base_model, 'main_input_name'):
model.main_input_name = base_model.main_input_name
# 3. Model-specific config attributes that might be set on the instance
# Note: forced_decoder_ids is deprecated in favor of task/language flags in generation_config
# We still copy it for backward compatibility, but the modern approach is preferred
for attr_name in ['forced_decoder_ids', 'suppress_tokens']:
if hasattr(base_model, attr_name):
attr_value = getattr(base_model, attr_name)
if attr_value is not None:
setattr(model, attr_name, attr_value)
return model
def save_pretrained(self, save_directory, **kwargs):
"""
Save the underlying model to the directory.
This ensures that the saved model is a standard OWSM model
that can be loaded with AutoModelForSpeechSeq2Seq for inference.
"""
print(f"Saving underlying model to {save_directory}...")
self.model.save_pretrained(save_directory, **kwargs)
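

# For inference, a checkpoint written by save_pretrained() can be reloaded through the
# standard transformers path. A minimal sketch (the checkpoint directory below is an
# illustrative placeholder):
#
#     from transformers import AutoModelForSpeechSeq2Seq
#     model = AutoModelForSpeechSeq2Seq.from_pretrained("./owsm-entity-checkpoint")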