shaun3141 committed
Commit 20e96fd · 1 Parent(s): 28b1bc3

Separate ESPnet and Whisper training modules with clear naming

- Created espnet_trainer.py: ESPnet-specific training (no HuggingFace fallbacks)
- Created whisper_trainer.py: Full HuggingFace transformers integration
- Updated UI with separate ESPnet and Whisper training tabs
- Fixed imports to use relative imports in training/__init__.py
- Removed old trainer.py (backed up as trainer_old.py.bak)
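
The new package surface is intentionally small; a minimal usage sketch, assuming `training/` is on the import path (both entry points share the same signature and return a `(markdown_message, metrics_or_none)` tuple, per the diffs below):

    from training import run_espnet_training_progress, run_whisper_training_progress

    msg, metrics = run_whisper_training_progress(epochs=3, batch_size=4, learning_rate=3e-5)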

training/__init__.py CHANGED
@@ -1,2 +1,15 @@
-"""Training logic for OWSM fine-tuning."""
-
+"""
+Training modules for Caribbean Voices Hackathon.
+
+Separate training modules:
+- espnet_trainer: ESPnet-specific training (no HuggingFace dependencies)
+- whisper_trainer: Whisper training with full HuggingFace integration
+"""
+
+from .espnet_trainer import run_espnet_training_progress
+from .whisper_trainer import run_whisper_training_progress
+
+__all__ = [
+    'run_espnet_training_progress',
+    'run_whisper_training_progress',
+]
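
Both trainers take an optional Gradio-style `progress` callback, invoked as `progress(fraction, desc=...)`. A stand-in for running outside the UI (this console callback is illustrative, not part of the commit):

    def console_progress(fraction, desc=""):
        # Mirrors the Gradio progress(fraction, desc=...) convention used by both trainers
        print(f"[{fraction:>4.0%}] {desc}")

    msg, info = run_espnet_training_progress(3, 4, 3e-5, progress=console_progress)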
training/espnet_trainer.py ADDED
@@ -0,0 +1,211 @@
+"""
+ESPnet-specific training for OWSM models.
+Uses ESPnet's native training framework - NO HuggingFace dependencies.
+"""
+import os
+import json
+import torch
+import numpy as np
+import random
+from typing import Tuple, Optional, Dict, Any
+from datasets import load_dataset, Audio
+from data.manager import ENTITIES_PATH, MODEL_OUTPUT_DIR, BASE_DIR
+
+# Set seeds for reproducibility
+SEED = 42
+random.seed(SEED)
+np.random.seed(SEED)
+torch.manual_seed(SEED)
+if torch.cuda.is_available():
+    torch.cuda.manual_seed_all(SEED)
+torch.use_deterministic_algorithms(True, warn_only=True)
+
+# ESPnet model configuration
+ESPNET_MODEL_NAME = "espnet/owsm_v3.1_ebf_small"
+TARGET_SR = 16000
+MAX_AUDIO_LENGTH = 30  # seconds
+HF_DATASET_NAME = "shaun3141/caribbean-voices-hackathon"
+
+
+def run_espnet_training_progress(epochs: int, batch_size: int, learning_rate: float, progress=None) -> Tuple[str, Optional[Dict[str, Any]]]:
+    """
+    Run ESPnet OWSM training setup with progress tracking.
+    Uses ESPnet's native training framework - NO HuggingFace fallbacks.
+    """
+    try:
+        if progress:
+            progress(0, desc="Initializing ESPnet training...")
+
+        # Check that ESPnet is installed - NO FALLBACKS
+        try:
+            from espnet2.bin.s2t_inference import Speech2Text
+        except ImportError as e:
+            raise RuntimeError(
+                f"❌ ESPnet is not installed!\n\n"
+                f"ESPnet is required for ESPnet model training.\n"
+                f"Install with: pip install espnet espnet_model_zoo\n\n"
+                f"Original error: {e}"
+            )
+
+        # Check prerequisites
+        if not os.path.exists(ENTITIES_PATH):
+            raise FileNotFoundError(
+                f"❌ Entities file not found at {ENTITIES_PATH}. "
+                f"Please extract entities first using the entity extraction tool."
+            )
+
+        if progress:
+            progress(0.05, desc="Loading entities...")
+        with open(ENTITIES_PATH, 'r') as f:
+            entities_data = json.load(f)
+
+        high_value_entities = set(entities_data['entities'])
+        print(f"Loaded {len(high_value_entities)} high-value entities")
+
+        if progress:
+            progress(0.1, desc="Loading dataset from Hugging Face...")
+
+        # Load dataset from HF
+        hf_token = os.getenv("HF_TOKEN")
+        print(f"Loading dataset: {HF_DATASET_NAME}")
+        dataset = load_dataset(HF_DATASET_NAME, token=hf_token)
+
+        if 'train' not in dataset:
+            raise ValueError(f"❌ Dataset {HF_DATASET_NAME} does not contain a 'train' split.")
+
+        train_full = dataset['train']
+        print(f"Loaded {len(train_full):,} total training samples")
+
+        # Cast to Audio to ensure the correct sampling rate
+        train_full = train_full.cast_column("audio", Audio(sampling_rate=TARGET_SR))
+
+        # Create train/val split
+        if progress:
+            progress(0.15, desc="Creating train/val split...")
+
+        split_dataset = train_full.train_test_split(test_size=0.1, seed=SEED)
+        train_dataset_raw = split_dataset['train']
+        val_dataset_raw = split_dataset['test']
+
+        print(f"Train: {len(train_dataset_raw):,} samples")
+        print(f"Val: {len(val_dataset_raw):,} samples")
+
+        # Load ESPnet model - NO FALLBACKS
+        if progress:
+            progress(0.2, desc=f"Loading ESPnet model: {ESPNET_MODEL_NAME}...")
+        print(f"\n{'='*70}")
+        print(f"Loading ESPnet model: {ESPNET_MODEL_NAME}")
+        print(f"{'='*70}")
+
+        espnet_model = Speech2Text.from_pretrained(ESPNET_MODEL_NAME)
+        print("✓ ESPnet model loaded successfully")
+
+        # Extract the tokenizer from the ESPnet model
+        if not hasattr(espnet_model, 'tokenizer'):
+            raise RuntimeError(
+                f"❌ ESPnet model {ESPNET_MODEL_NAME} does not have a 'tokenizer' attribute. "
+                f"This is required for training. The model may not be compatible with fine-tuning."
+            )
+
+        if espnet_model.tokenizer is None:
+            raise RuntimeError(
+                f"❌ ESPnet model {ESPNET_MODEL_NAME} has a None tokenizer. "
+                f"This is required for training. The model may not be properly initialized."
+            )
+
+        espnet_tokenizer = espnet_model.tokenizer
+        print("✓ Tokenizer extracted from ESPnet model")
+
+        # Extract the ASR model
+        if not hasattr(espnet_model, 'asr_model'):
+            raise RuntimeError(
+                f"❌ ESPnet model {ESPNET_MODEL_NAME} does not have an 'asr_model' attribute. "
+                f"This is required for training. The model may not be compatible with fine-tuning."
+            )
+
+        if espnet_model.asr_model is None:
+            raise RuntimeError(
+                f"❌ ESPnet model {ESPNET_MODEL_NAME} has a None asr_model. "
+                f"This is required for training. The model may not be properly initialized."
+            )
+
+        espnet_asr_model = espnet_model.asr_model
+        print("✓ ASR model extracted from ESPnet")
+
+        # ESPnet training requires ESPnet recipes;
+        # for now, print clear instructions instead of training in-process.
+        if progress:
+            progress(0.3, desc="Preparing ESPnet training setup...")
+
+        print(f"\n{'='*70}")
+        print("ESPnet Training Setup")
+        print(f"{'='*70}")
+        print("ESPnet models require ESPnet's native training framework.")
+        print("To fine-tune ESPnet models, you need to:")
+        print("1. Set up an ESPnet recipe (e.g., egs2/librispeech/asr1)")
+        print("2. Modify the recipe to use your data")
+        print("3. Run the ESPnet training script")
+        print("\nThe model and tokenizer have been loaded successfully.")
+        print("You can use them with ESPnet's training recipes.")
+        print(f"{'='*70}\n")
+
+        # Save model info for ESPnet recipes
+        model_info = {
+            'model_name': ESPNET_MODEL_NAME,
+            'entities': list(high_value_entities),
+            'train_samples': len(train_dataset_raw),
+            'val_samples': len(val_dataset_raw),
+            'training_framework': 'espnet',
+            'note': 'This model requires ESPnet native training recipes for fine-tuning'
+        }
+
+        os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
+        model_info_path = os.path.join(MODEL_OUTPUT_DIR, "espnet_model_info.json")
+        with open(model_info_path, 'w') as f:
+            json.dump(model_info, f, indent=2)
+
+        # Save entities
+        entities_output_path = os.path.join(MODEL_OUTPUT_DIR, "caribbean_entities.json")
+        with open(entities_output_path, 'w') as f:
+            json.dump(entities_data, f, indent=2)
+
+        if progress:
+            progress(1.0, desc="Complete!")
+
+        success_msg = f"""
+## ✅ ESPnet Model Loaded Successfully!
+
+**Model:** {ESPNET_MODEL_NAME}
+**Output Directory:** {MODEL_OUTPUT_DIR}
+
+**Model Components:**
+- ✓ ESPnet Speech2Text model loaded
+- ✓ Tokenizer extracted
+- ✓ ASR model extracted
+
+**Files Saved:**
+- Model info: `{model_info_path}`
+- Entities: `{entities_output_path}`
+
+**Next Steps:**
+ESPnet models require ESPnet's native training framework for fine-tuning.
+Use ESPnet training recipes to fine-tune this model.
+
+**Training Data:**
+- Train samples: {len(train_dataset_raw):,}
+- Val samples: {len(val_dataset_raw):,}
+- Entities: {len(high_value_entities)}
+
+**Note:** This training interface loads the ESPnet model successfully.
+For actual fine-tuning, use ESPnet's training recipes.
+"""
+
+        return success_msg, model_info
+
+    except Exception as e:
+        import traceback
+        error_msg = f"❌ Error during ESPnet training setup: {str(e)}\n\n{traceback.format_exc()}"
+        print(error_msg)
+        if progress:
+            progress(1.0, desc="Error!")
+        return error_msg, None
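
Because fine-tuning is deferred to ESPnet recipes, this module only records the setup in `espnet_model_info.json`. A hedged sketch of reusing that file to reload the model for decoding (the `model/` path, the `soundfile` dependency, and the 16 kHz mono WAV are assumptions, not taken from this repo):

    import json
    import soundfile as sf
    from espnet2.bin.s2t_inference import Speech2Text

    with open("model/espnet_model_info.json") as f:  # placeholder for MODEL_OUTPUT_DIR
        info = json.load(f)

    s2t = Speech2Text.from_pretrained(info["model_name"])
    speech, sr = sf.read("example.wav")  # assumed 16 kHz mono
    results = s2t(speech)
    print(results[0][0])  # best hypothesis text is the first element of the top result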
training/whisper_trainer.py ADDED
@@ -0,0 +1,342 @@
+"""
+Whisper training using HuggingFace transformers.
+Full integration with HuggingFace training features.
+"""
+import os
+import json
+import torch
+import numpy as np
+import random
+from typing import Tuple, Optional, Dict, Any
+from datasets import load_dataset, Audio
+from transformers import (
+    WhisperProcessor,
+    WhisperForConditionalGeneration,
+    Seq2SeqTrainingArguments,
+    Seq2SeqTrainer,
+    DataCollatorForSeq2Seq,
+    EarlyStoppingCallback,
+)
+from owsm_model import OWSMWithEntityLoss
+from data.manager import ENTITIES_PATH, MODEL_OUTPUT_DIR, BASE_DIR
+
+# Set seeds for reproducibility
+SEED = 42
+random.seed(SEED)
+np.random.seed(SEED)
+torch.manual_seed(SEED)
+if torch.cuda.is_available():
+    torch.cuda.manual_seed_all(SEED)
+torch.use_deterministic_algorithms(True, warn_only=True)
+
+# Whisper model configuration
+WHISPER_MODEL_NAME = "openai/whisper-small"
+TARGET_SR = 16000
+MAX_AUDIO_LENGTH = 30  # seconds
+HF_DATASET_NAME = "shaun3141/caribbean-voices-hackathon"
+
+
+def compute_wer_metric(predictions, labels, tokenizer):
+    """Compute Word Error Rate metric."""
+    try:
+        import jiwer
+    except ImportError:
+        # Fallback: rough WER approximation if jiwer is not available
+        def simple_wer(ref, hyp):
+            ref_words = ref.lower().split()
+            hyp_words = hyp.lower().split()
+            if len(ref_words) == 0:
+                return 1.0 if len(hyp_words) > 0 else 0.0
+
+            if ' '.join(ref_words) == ' '.join(hyp_words):
+                return 0.0
+
+            # Word-overlap approximation (not a true Levenshtein WER)
+            common = len(set(ref_words) & set(hyp_words))
+            return 1.0 - (common / len(ref_words))
+
+        # Decode predictions and labels
+        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+
+        # Replace -100 with the pad token for decoding
+        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        wer_scores = [simple_wer(ref, hyp) for ref, hyp in zip(decoded_labels, decoded_preds)]
+        return {"wer": np.mean(wer_scores)}
+
+    # Decode predictions and labels
+    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+
+    # Replace -100 with the pad token for decoding
+    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+    # Compute WER using jiwer
+    wer = jiwer.wer(decoded_labels, decoded_preds)
+    return {"wer": wer}
+
+
+def prepare_whisper_dataset(dataset, processor):
+    """
+    Prepare a dataset for Whisper training using Hugging Face Datasets.
+    """
+
+    def prepare_batch(batch):
+        """Process a batch of examples."""
+        audio = batch["audio"]
+        transcriptions = batch["transcription"]
+
+        # Extract input features with the processor
+        inputs = processor(
+            [x["array"] for x in audio],
+            sampling_rate=TARGET_SR,
+            return_tensors="pt",
+            padding=True,
+        )
+
+        # Tokenize transcriptions as labels
+        labels = processor.tokenizer(
+            transcriptions,
+            return_tensors="pt",
+            padding=True,
+        ).input_ids
+
+        # Replace padding token ids in the labels with -100 so they are ignored by the loss
+        labels[labels == processor.tokenizer.pad_token_id] = -100
+
+        batch["input_features"] = inputs.input_features
+        batch["labels"] = labels
+
+        return batch
+
+    # Remove columns that are not needed after preprocessing
+    column_names = dataset.column_names
+
+    # Process in batches
+    dataset = dataset.map(
+        prepare_batch,
+        batched=True,
+        batch_size=16,
+        remove_columns=column_names,
+        desc="Preprocessing dataset",
+    )
+
+    return dataset
+
+
+def run_whisper_training_progress(epochs: int, batch_size: int, learning_rate: float, progress=None) -> Tuple[str, Optional[Dict[str, Any]]]:
+    """
+    Run Whisper training with progress tracking using HuggingFace transformers.
+    Full integration with HuggingFace training features.
+    """
+    try:
+        if progress:
+            progress(0, desc="Preparing Whisper training...")
+
+        # Check prerequisites
+        if not os.path.exists(ENTITIES_PATH):
+            raise FileNotFoundError(
+                f"❌ Entities file not found at {ENTITIES_PATH}. "
+                f"Please extract entities first using the entity extraction tool."
+            )
+
+        if progress:
+            progress(0.05, desc="Loading entities...")
+        with open(ENTITIES_PATH, 'r') as f:
+            entities_data = json.load(f)
+
+        high_value_entities = set(entities_data['entities'])
+        print(f"Loaded {len(high_value_entities)} high-value entities")
+
+        if progress:
+            progress(0.1, desc="Loading dataset from Hugging Face...")
+
+        # Load dataset from HF
+        hf_token = os.getenv("HF_TOKEN")
+        print(f"Loading dataset: {HF_DATASET_NAME}")
+        dataset = load_dataset(HF_DATASET_NAME, token=hf_token)
+
+        if 'train' not in dataset:
+            raise ValueError(f"❌ Dataset {HF_DATASET_NAME} does not contain a 'train' split.")
+
+        train_full = dataset['train']
+        print(f"Loaded {len(train_full):,} total training samples")
+
+        # Cast to Audio to ensure the correct sampling rate
+        train_full = train_full.cast_column("audio", Audio(sampling_rate=TARGET_SR))
+
+        # Create train/val split
+        if progress:
+            progress(0.15, desc="Creating train/val split...")
+
+        split_dataset = train_full.train_test_split(test_size=0.1, seed=SEED)
+        train_dataset_raw = split_dataset['train']
+        val_dataset_raw = split_dataset['test']
+
+        print(f"Train: {len(train_dataset_raw):,} samples")
+        print(f"Val: {len(val_dataset_raw):,} samples")
+
+        # Load Whisper processor
+        if progress:
+            progress(0.2, desc=f"Loading Whisper processor: {WHISPER_MODEL_NAME}...")
+        print(f"\nLoading Whisper processor: {WHISPER_MODEL_NAME}")
+
+        processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
+        print("✓ Whisper processor loaded successfully")
+
+        # Load Whisper model
+        if progress:
+            progress(0.25, desc=f"Loading Whisper model: {WHISPER_MODEL_NAME}...")
+        print(f"\nLoading Whisper model: {WHISPER_MODEL_NAME}")
+
+        # Use our wrapper class with entity-weighted loss
+        model = OWSMWithEntityLoss.from_pretrained(
+            WHISPER_MODEL_NAME,
+            tokenizer=processor.tokenizer,
+            high_value_tokens=high_value_entities,
+            entity_weight=3.0,
+        )
+
+        print("✓ Whisper model loaded successfully")
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        print(f"Model on device: {device}")
+
+        # Prepare datasets
+        if progress:
+            progress(0.3, desc="Preprocessing training dataset...")
+        print("\nPreprocessing training dataset...")
+        train_dataset = prepare_whisper_dataset(train_dataset_raw, processor)
+
+        if progress:
+            progress(0.4, desc="Preprocessing validation dataset...")
+        print("Preprocessing validation dataset...")
+        val_dataset = prepare_whisper_dataset(val_dataset_raw, processor)
+
+        # Training arguments
+        if progress:
+            progress(0.5, desc="Setting up training arguments...")
+
+        training_args = Seq2SeqTrainingArguments(
+            output_dir=MODEL_OUTPUT_DIR,
+            per_device_train_batch_size=batch_size,
+            per_device_eval_batch_size=batch_size,
+            gradient_accumulation_steps=4,
+            learning_rate=learning_rate,
+            warmup_steps=500,
+            num_train_epochs=epochs,
+            evaluation_strategy="steps",
+            eval_steps=1000,
+            save_strategy="steps",
+            save_steps=1000,
+            logging_steps=100,
+            load_best_model_at_end=True,
+            metric_for_best_model="wer",
+            greater_is_better=False,
+            save_total_limit=3,
+            fp16=torch.cuda.is_available(),
+            dataloader_num_workers=4,
+            report_to="none",
+            seed=SEED,
+            predict_with_generate=True,
+            generation_max_length=200,
+        )
+
+        # Data collator (DataCollatorForSeq2Seq expects the tokenizer, not the processor)
+        data_collator = DataCollatorForSeq2Seq(
+            tokenizer=processor.tokenizer,
+            model=model,
+            padding=True,
+        )
+
+        # Custom compute_metrics function for WER
+        def compute_metrics(eval_pred):
+            predictions, labels = eval_pred
+            return compute_wer_metric(predictions, labels, processor.tokenizer)
+
+        # Trainer
+        trainer = Seq2SeqTrainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset,
+            eval_dataset=val_dataset,
+            data_collator=data_collator,
+            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
+            compute_metrics=compute_metrics,
+        )
+
+        # Train
+        if progress:
+            progress(0.6, desc="Starting training...")
+
+        print("\n" + "=" * 70)
+        print("STARTING WHISPER TRAINING")
+        print("=" * 70)
+        print(f"Model: {WHISPER_MODEL_NAME}")
+        print(f"Epochs: {epochs}")
+        print(f"Batch Size: {batch_size}")
+        print(f"Learning Rate: {learning_rate}")
+        print(f"Train Samples: {len(train_dataset):,}")
+        print(f"Val Samples: {len(val_dataset):,}")
+        print("=" * 70)
+
+        trainer.train()
+
+        # Save final model
+        if progress:
+            progress(0.95, desc="Saving model...")
+
+        print(f"\nSaving model to {MODEL_OUTPUT_DIR}...")
+        model.save_pretrained(MODEL_OUTPUT_DIR)
+        processor.save_pretrained(MODEL_OUTPUT_DIR)
+
+        # Save entities for inference
+        entities_output_path = os.path.join(MODEL_OUTPUT_DIR, "caribbean_entities.json")
+        with open(entities_output_path, 'w') as f:
+            json.dump(entities_data, f, indent=2)
+
+        if progress:
+            progress(1.0, desc="Complete!")
+
+        final_metrics = trainer.evaluate()
+        wer = final_metrics.get('eval_wer', 'N/A')
+        loss = final_metrics.get('eval_loss', 'N/A')
+
+        wer_str = f"{wer:.4f}" if isinstance(wer, (int, float)) else str(wer)
+        loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else str(loss)
+
+        success_msg = f"""
+## ✅ Whisper Training Complete!
+
+**Model:** {WHISPER_MODEL_NAME}
+**Output Directory:** {MODEL_OUTPUT_DIR}
+
+**Final Metrics:**
+- Word Error Rate (WER): {wer_str}
+- Validation Loss: {loss_str}
+
+**Files Saved:**
+- Model weights: `{MODEL_OUTPUT_DIR}`
+- Processor: `{MODEL_OUTPUT_DIR}`
+- Entities: `{entities_output_path}`
+
+The model is now ready for inference!
+"""
+
+        return success_msg, final_metrics
+
+    except Exception as e:
+        import traceback
+        error_msg = f"❌ Error during Whisper training: {str(e)}\n\n{traceback.format_exc()}"
+        print(error_msg)
+        if progress:
+            progress(1.0, desc="Error!")
+        return error_msg, None
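
Once `trainer.train()` finishes and the directory is saved, it can be reloaded with the standard HuggingFace API. A minimal inference sketch, assuming the `OWSMWithEntityLoss` wrapper saves standard Whisper weights; `model/` stands in for `MODEL_OUTPUT_DIR`, and the zero waveform is a placeholder for real 16 kHz audio:

    import numpy as np
    import torch
    from transformers import WhisperProcessor, WhisperForConditionalGeneration

    processor = WhisperProcessor.from_pretrained("model/")
    model = WhisperForConditionalGeneration.from_pretrained("model/").eval()

    waveform = np.zeros(16000, dtype=np.float32)  # stand-in; use a real 16 kHz mono array
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        ids = model.generate(inputs.input_features, max_length=200)
    print(processor.batch_decode(ids, skip_special_tokens=True)[0])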
ui/interface.py CHANGED
@@ -7,7 +7,8 @@ from datetime import datetime
 # Import modules
 from utils.status import get_status_display, get_data_loading_status
 from utils.entities import extract_entities_progress
-from training.trainer import run_training_progress
+from training.espnet_trainer import run_espnet_training_progress
+from training.whisper_trainer import run_whisper_training_progress
 from models.inference import transcribe_audio, run_inference_owsm
 from models.loader import get_available_models
 from data.loader import load_data_from_hf_dataset
@@ -168,30 +169,69 @@ def create_interface():
             outputs=[extract_output, extract_json]
         )

-        # Tab 4: Training
+        # Tab 4: Training (with sub-tabs for ESPnet and Whisper)
         with gr.Tab("🏋️ Training"):
-            gr.Markdown("### Fine-tune OWSM v3.1 Model")
+            gr.Markdown("### Model Training")
             gr.Markdown("""
-            Fine-tune OWSM v3.1 on Caribbean Voices dataset with entity-weighted loss.
-            **Note:** Full training requires ESPnet recipes. See documentation for details.
+            Choose your training framework:
+            - **ESPnet Training**: For ESPnet OWSM models (requires ESPnet recipes)
+            - **Whisper Training**: For Whisper models (full HuggingFace integration)
             """)

-            with gr.Row():
-                with gr.Column():
-                    train_epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
-                    train_batch_size = gr.Slider(1, 32, value=4, step=1, label="Batch Size")
-                    train_lr = gr.Slider(1e-6, 1e-3, value=3e-5, step=1e-6, label="Learning Rate")
-                    train_btn = gr.Button("Start Training", variant="primary")
-
-                with gr.Column():
-                    train_output = gr.Markdown()
-                    train_metrics = gr.JSON(label="Training Metrics")
-
-            train_btn.click(
-                fn=run_training_progress,
-                inputs=[train_epochs, train_batch_size, train_lr],
-                outputs=[train_output, train_metrics]
-            )
+            with gr.Tabs() as training_tabs:
+                # ESPnet Training Tab
+                with gr.Tab("🔧 ESPnet Training"):
+                    gr.Markdown("### ESPnet OWSM Model Training")
+                    gr.Markdown("""
+                    **ESPnet Training** - Uses ESPnet's native framework.
+
+                    This loads ESPnet models and prepares them for training with ESPnet recipes.
+                    Full fine-tuning requires ESPnet training recipes.
+                    """)
+
+                    with gr.Row():
+                        with gr.Column():
+                            espnet_train_epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs (for ESPnet recipes)")
+                            espnet_train_batch_size = gr.Slider(1, 32, value=4, step=1, label="Batch Size (for ESPnet recipes)")
+                            espnet_train_lr = gr.Slider(1e-6, 1e-3, value=3e-5, step=1e-6, label="Learning Rate (for ESPnet recipes)")
+                            espnet_train_btn = gr.Button("Load ESPnet Model", variant="primary")
+
+                        with gr.Column():
+                            espnet_train_output = gr.Markdown()
+                            espnet_train_metrics = gr.JSON(label="Model Info")
+
+                    espnet_train_btn.click(
+                        fn=run_espnet_training_progress,
+                        inputs=[espnet_train_epochs, espnet_train_batch_size, espnet_train_lr],
+                        outputs=[espnet_train_output, espnet_train_metrics]
+                    )
+
+                # Whisper Training Tab
+                with gr.Tab("🎤 Whisper Training"):
+                    gr.Markdown("### Whisper Model Training")
+                    gr.Markdown("""
+                    **Whisper Training** - Full HuggingFace transformers integration.
+
+                    Fine-tune Whisper models with entity-weighted loss using HuggingFace's training framework.
+                    Includes full support for HuggingFace features like early stopping, WER metrics, etc.
+                    """)
+
+                    with gr.Row():
+                        with gr.Column():
+                            whisper_train_epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
+                            whisper_train_batch_size = gr.Slider(1, 32, value=4, step=1, label="Batch Size")
+                            whisper_train_lr = gr.Slider(1e-6, 1e-3, value=3e-5, step=1e-6, label="Learning Rate")
+                            whisper_train_btn = gr.Button("Start Whisper Training", variant="primary")
+
+                        with gr.Column():
+                            whisper_train_output = gr.Markdown()
+                            whisper_train_metrics = gr.JSON(label="Training Metrics")
+
+                    whisper_train_btn.click(
+                        fn=run_whisper_training_progress,
+                        inputs=[whisper_train_epochs, whisper_train_batch_size, whisper_train_lr],
+                        outputs=[whisper_train_output, whisper_train_metrics]
+                    )

         # Tab 5: Inference
         with gr.Tab("🚀 Inference"):
@@ -261,15 +301,18 @@ def create_interface():

     ### Workflow
     1. **Extract Entities**: Run entity extraction on training data
-    2. **Train Model**: Fine-tune OWSM (requires ESPnet recipes)
+    2. **Train Model**:
+       - **ESPnet Training**: Load ESPnet models (requires ESPnet recipes for fine-tuning)
+       - **Whisper Training**: Full HuggingFace fine-tuning with entity-weighted loss
     3. **Run Inference**: Generate test set transcriptions
     4. **Download Results**: Get submission CSV file

     ### Technical Details
-    - Framework: ESPnet + PyTorch
-    - Model: OWSM v3.1 E-Branchformer
-    - Entity Extraction: Frequency + capitalization analysis
-    - Training: Entity-weighted cross-entropy loss
+    - **ESPnet Framework**: ESPnet + PyTorch for ESPnet OWSM models
+    - **Whisper Framework**: HuggingFace transformers for Whisper models
+    - **Model**: OWSM v3.1 E-Branchformer (ESPnet) or Whisper (HuggingFace)
+    - **Entity Extraction**: Frequency + capitalization analysis
+    - **Training**: Entity-weighted cross-entropy loss

     ### Documentation
     See `ESPNET_OWSM_SETUP.md` and `IMPLEMENTATION_SUMMARY.md` for details.
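
The entity-weighted cross-entropy loss referenced above is implemented in `OWSMWithEntityLoss`, which is not part of this commit. Conceptually it upweights the per-token loss wherever the label token belongs to a high-value entity; a rough sketch under that assumption, not the actual implementation:

    import torch
    import torch.nn.functional as F

    def entity_weighted_ce(logits, labels, entity_mask, entity_weight=3.0):
        # logits: (batch, seq, vocab); labels: (batch, seq) with -100 at padding positions
        ce = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=-100, reduction="none")
        weights = 1.0 + (entity_weight - 1.0) * entity_mask.float()  # 3.0 on entity tokens, 1.0 elsewhere
        valid = (labels != -100).float()
        return (ce * weights * valid).sum() / valid.sum().clamp(min=1.0)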