""" |
|
|
Custom OWSM model with entity-weighted loss for Caribbean Voices challenge. |
|
|
This implements loss re-weighting for proper nouns without external data. |
|
|
""" |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoModelForSpeechSeq2Seq, PreTrainedModel |
|
|
from transformers.modeling_outputs import Seq2SeqLMOutput |
|
|
from typing import Set, Optional, Dict, Any |
|
|
|
|
|
class OWSMWithEntityLoss(PreTrainedModel): |
|
|
""" |
|
|
Wrapper around OWSM model that implements weighted cross-entropy loss |
|
|
to up-weight errors on entity tokens. |
|
|
|
|
|
This model wraps the base model using composition rather than inheritance |
|
|
to avoid issues with the AutoModel factory pattern. |
|
|
""" |
    def __init__(self, config, base_model, tokenizer, high_value_tokens: Set[str], entity_weight: float = 3.0):
        """
        Args:
            config: Model configuration
            base_model: The instantiated base model (SpeechEncoderDecoderModel)
            tokenizer: Tokenizer for converting entity words to token IDs
            high_value_tokens: Set of entity words (lowercase) to up-weight
            entity_weight: Multiplier for entity token errors (default: 3.0)
        """
        super().__init__(config)
        self.model = base_model
        self.tokenizer = tokenizer
        self.entity_weight = entity_weight

        # Map each entity word to the set of token IDs it tokenizes into.
        self.entity_word_to_token_ids: Dict[str, Set[int]] = {}
        all_entity_token_ids = set()

        print(f"Building entity token ID set from {len(high_value_tokens)} entities...")
        for word in high_value_tokens:
            tokens = tokenizer.tokenize(word)
            if tokens:
                token_ids = tokenizer.convert_tokens_to_ids(tokens)
                token_id_set = set(token_ids)
                self.entity_word_to_token_ids[word] = token_id_set
                all_entity_token_ids.update(token_id_set)

        print(f" → Mapped to {len(all_entity_token_ids)} unique token IDs")
        if self.entity_word_to_token_ids:
            avg_tokens = sum(len(ids) for ids in self.entity_word_to_token_ids.values()) / len(self.entity_word_to_token_ids)
            print(f" → Average tokens per entity: {avg_tokens:.2f}")

        # One loss weight per vocabulary entry, stored as a buffer so it moves
        # with the model across devices.
        vocab_size = config.vocab_size if hasattr(config, 'vocab_size') else len(tokenizer)
        self.register_buffer('vocab_weights', torch.ones(vocab_size, dtype=torch.float32))
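
        # For example (illustrative IDs, not real tokenizer output): if "tobago"
        # tokenizes to IDs [1204, 887], the loop below sets vocab_weights[1204]
        # and vocab_weights[887] to entity_weight (3.0 by default), while every
        # other entry stays at 1.0.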
        for token_id in all_entity_token_ids:
            if 0 <= token_id < vocab_size:
                self.vocab_weights[token_id] = self.entity_weight

        self.entity_token_ids = all_entity_token_ids
        self.high_value_tokens = high_value_tokens

    def get_encoder(self):
        """Delegate to sub-model's encoder."""
        return self.model.get_encoder()

    def get_decoder(self):
        """Delegate to sub-model's decoder."""
        return self.model.get_decoder()

    def forward(self, input_features=None, attention_mask=None, decoder_input_ids=None, labels=None, **kwargs):
        """
        Forward pass that delegates to the underlying model and, when labels
        are provided, replaces its loss with the entity-weighted loss.
        """
        # Drop Trainer-injected kwargs that the base model's forward() does not accept.
        model_kwargs = {k: v for k, v in kwargs.items() if k != 'num_items_in_batch'}

        outputs = self.model(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
            return_dict=True,
            **model_kwargs
        )

        if labels is None:
            return outputs

        logits = outputs.logits
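
        # Entity-weighted cross-entropy, computed token by token:
        #   1. flatten logits/labels to (batch * seq_len, ...)
        #   2. mask out padding positions (label == -100)
        #   3. look up a per-token weight from vocab_weights (entity_weight for
        #      entity token IDs, 1.0 otherwise)
        #   4. multiply the unreduced cross-entropy by those weights and average
        #      over the valid positions
        # Note that the average divides by the number of valid tokens, not by the
        # sum of the weights, so batches rich in entity tokens yield a larger loss.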
        flat_logits = logits.view(-1, logits.size(-1))
        flat_labels = labels.view(-1)

        valid_mask = (flat_labels != -100)

        safe_labels = flat_labels.clone()
        safe_labels[~valid_mask] = 0

        weights = self.vocab_weights[safe_labels]

        loss_fct = nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(flat_logits, flat_labels)

        weighted_loss = loss * weights

        weighted_loss = weighted_loss[valid_mask]

        if weighted_loss.numel() == 0:
            final_loss = torch.tensor(0.0, device=logits.device, requires_grad=True)
        else:
            final_loss = weighted_loss.mean()

        return Seq2SeqLMOutput(
            loss=final_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def generate(self, *args, **kwargs):
        """Delegate generation to the underlying model."""
        return self.model.generate(*args, **kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Delegate to underlying model."""
        return self.model.prepare_inputs_for_generation(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str,
                        tokenizer, high_value_tokens: Set[str],
                        entity_weight: float = 3.0, **kwargs):
        """Load a pretrained OWSM model and wrap it with the entity-weighted loss."""
        base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )

        model = cls(
            config=base_model.config,
            base_model=base_model,
            tokenizer=tokenizer,
            high_value_tokens=high_value_tokens,
            entity_weight=entity_weight
        )
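
        # Give the wrapper a usable generation config so .generate() works:
        # reuse the base model's config when it has one, otherwise build one
        # from the model config (falling back to an empty GenerationConfig).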
        if hasattr(base_model, 'generation_config') and base_model.generation_config is not None:
            model.generation_config = base_model.generation_config
        else:
            try:
                from transformers import GenerationConfig
                model.generation_config = GenerationConfig.from_model_config(model.config)
            except Exception:
                from transformers import GenerationConfig
                model.generation_config = GenerationConfig()

        # Default to transcription if the config exposes a task field but leaves it unset.
        if hasattr(model.generation_config, 'task'):
            if model.generation_config.task is None:
                model.generation_config.task = "transcribe"

        # Remove any forced decoder IDs inherited from the base checkpoint.
        if hasattr(model.generation_config, 'forced_decoder_ids') and model.generation_config.forced_decoder_ids is not None:
            model.generation_config.forced_decoder_ids = None

        # Make sure generation knows the padding token.
        if hasattr(tokenizer, 'pad_token_id') and tokenizer.pad_token_id is not None:
            if hasattr(model.generation_config, 'pad_token_id'):
                model.generation_config.pad_token_id = tokenizer.pad_token_id

        # Keep the base model's expected input name (e.g. "input_features").
        if hasattr(base_model, 'main_input_name'):
            model.main_input_name = base_model.main_input_name

        # Copy legacy generation attributes if the base model still carries them.
        for attr_name in ['forced_decoder_ids', 'suppress_tokens']:
            if hasattr(base_model, attr_name):
                attr_value = getattr(base_model, attr_name)
                if attr_value is not None:
                    setattr(model, attr_name, attr_value)

        return model

    def save_pretrained(self, save_directory, **kwargs):
        """
        Save the underlying model to the directory.

        This ensures that the saved model is a standard OWSM model
        that can be loaded with AutoModelForSpeechSeq2Seq for inference.
        """
        print(f"Saving underlying model to {save_directory}...")
        self.model.save_pretrained(save_directory, **kwargs)
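

# Inference-time reload sketch (the output directory name is illustrative):
#
#   from transformers import AutoModelForSpeechSeq2Seq
#   model = AutoModelForSpeechSeq2Seq.from_pretrained("outputs/owsm-entity-finetuned")
#
# The entity weighting only changes the training loss, so the checkpoint written by
# save_pretrained() is a plain base model and needs no custom code at inference time.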