shaun3141 committed on
Commit 1264f26 · 1 Parent(s): c41a0cd

Add configurable augmentation settings in UI and persistent logging


- Add UI controls for speed augmentation (min/max/count) and SpecAugment parameters
- Make augmentation settings configurable before training starts
- Implement persistent logging that survives space restarts
- Add log download functionality in training UI
- Logs are saved to a persistent location and can be downloaded
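
As a usage sketch (not part of the commit itself): the expanded signature of run_whisper_training_progress introduced here (see the whisper_trainer.py diff below) can also be called directly, e.g. for a headless run. The argument names come from the diff; the values are purely illustrative.

    from training.whisper_trainer import run_whisper_training_progress

    # Illustrative values only; the UI sliders expose the same parameters.
    message, metrics = run_whisper_training_progress(
        epochs=3,
        batch_size=4,
        learning_rate=3e-5,
        speed_aug_enabled=True,
        speed_factor_min=0.9,
        speed_factor_max=1.1,
        speed_factor_count=3,
        specaug_enabled=True,
        specaug_time_mask=27,
        specaug_freq_mask=10,
        specaug_time_warp=True,
        specaug_warp_param=40,
    )
    print(message)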

Files changed (3)
  1. training/whisper_trainer.py +89 -22
  2. ui/interface.py +71 -4
  3. utils/logging.py +116 -0
training/whisper_trainer.py CHANGED
@@ -27,6 +27,7 @@ from training.augmentation import (
     get_deterministic_speed_factor_from_id,
     expand_dataset_with_speed_augmentation,
 )
+from utils.logging import PersistentLogger, get_latest_log_file, get_log_directory

 # Disable dataset caching to save disk space
 disable_caching()
@@ -199,7 +200,19 @@ def get_cache_key(dataset_name: str, model_name: str, split: str, seed: int) ->
     return hashlib.md5(cache_string.encode()).hexdigest()


-def prepare_whisper_dataset(dataset, processor, dataset_name: str = None, model_name: str = None, split: str = None, use_cache: bool = True):
+def prepare_whisper_dataset(
+    dataset,
+    processor,
+    dataset_name: str = None,
+    model_name: str = None,
+    split: str = None,
+    use_cache: bool = True,
+    specaug_enabled: bool = True,
+    specaug_time_mask: int = 27,
+    specaug_freq_mask: int = 10,
+    specaug_time_warp: bool = True,
+    specaug_warp_param: int = 40,
+):
     """
     Prepare dataset for Whisper training using Hugging Face Datasets.
     Supports caching to avoid reprocessing.
@@ -300,10 +313,16 @@ def prepare_whisper_dataset(dataset, processor, dataset_name: str = None, model_
             feat = feat[0:1]

         # Apply spectrogram augmentations (SpecAugment, time warping) to features (only during training)
-        if is_training:
+        if is_training and specaug_enabled:
             # feat is [1, n_mels, seq_len], remove batch dim for augmentation: [n_mels, seq_len]
             feat_2d = feat[0]  # Remove batch dimension
-            feat_2d = apply_spectrogram_augmentations(feat_2d, apply_time_warp=True)
+            feat_2d = apply_spectrogram_augmentations(
+                feat_2d,
+                time_mask_param=specaug_time_mask,
+                freq_mask_param=specaug_freq_mask,
+                apply_time_warp=specaug_time_warp,
+                warp_param=specaug_warp_param,
+            )
             # Add batch dimension back: [1, n_mels, seq_len]
             feat = np.expand_dims(feat_2d, axis=0)

@@ -400,15 +419,35 @@ def prepare_whisper_dataset(dataset, processor, dataset_name: str = None, model_
     return dataset


-def run_whisper_training_progress(epochs: int, batch_size: int, learning_rate: float, progress=None) -> Tuple[str, Optional[Dict[str, Any]]]:
+def run_whisper_training_progress(
+    epochs: int,
+    batch_size: int,
+    learning_rate: float,
+    speed_aug_enabled: bool = True,
+    speed_factor_min: float = 0.9,
+    speed_factor_max: float = 1.1,
+    speed_factor_count: int = 3,
+    specaug_enabled: bool = True,
+    specaug_time_mask: int = 27,
+    specaug_freq_mask: int = 10,
+    specaug_time_warp: bool = True,
+    specaug_warp_param: int = 40,
+    progress=None
+) -> Tuple[str, Optional[Dict[str, Any]]]:
     """
     Run Whisper training with progress tracking using HuggingFace transformers.
     Full integration with HuggingFace training features.
     """
+    # Set up persistent logging
+    logger = PersistentLogger("whisper_training")
+
     try:
         if progress:
             progress(0, desc="Preparing Whisper training...")

+        print(f"📝 Training logs will be saved to: {get_log_directory()}")
+        print(f"📝 Latest log file: {get_latest_log_file('whisper_training')}")
+
         # Check prerequisites
         if not os.path.exists(ENTITIES_PATH):
             raise FileNotFoundError(
@@ -442,22 +481,42 @@ def run_whisper_training_progress(epochs: int, batch_size: int, learning_rate: f
         train_full = train_full.cast_column("audio", Audio(sampling_rate=TARGET_SR))

         # Expand dataset with speed augmentation (proactive augmentation)
-        # Creates 3 versions of each sample: 0.9x, 1.0x, 1.1x speed
-        if progress:
-            progress(0.12, desc="Expanding dataset with speed augmentation...")
-
-        print("\n" + "=" * 70)
-        print("EXPANDING DATASET WITH SPEED AUGMENTATION")
-        print("=" * 70)
-        train_full = expand_dataset_with_speed_augmentation(
-            train_full,
-            speed_factors=[0.9, 1.0, 1.1],
-            id_column="id",
-            audio_column="audio",
-            transcription_column="transcription",
-            target_sr=TARGET_SR,
-        )
-        print("=" * 70 + "\n")
+        # Creates multiple versions of each sample based on speed factors
+        if speed_aug_enabled:
+            if progress:
+                progress(0.12, desc="Expanding dataset with speed augmentation...")
+
+            # Generate speed factors from min/max/count
+            if speed_factor_count == 1:
+                speed_factors = [1.0]
+            elif speed_factor_count == 2:
+                speed_factors = [speed_factor_min, speed_factor_max]
+            else:
+                # Generate evenly spaced factors including min, max, and intermediate values
+                speed_factors = [
+                    speed_factor_min + (speed_factor_max - speed_factor_min) * i / (speed_factor_count - 1)
+                    for i in range(speed_factor_count)
+                ]
+                # Ensure 1.0 is included if it's within range
+                if speed_factor_min <= 1.0 <= speed_factor_max:
+                    speed_factors.append(1.0)
+                speed_factors = sorted(set(speed_factors))  # Remove duplicates and sort
+
+            print("\n" + "=" * 70)
+            print("EXPANDING DATASET WITH SPEED AUGMENTATION")
+            print("=" * 70)
+            print(f"Speed factors: {speed_factors}")
+            train_full = expand_dataset_with_speed_augmentation(
+                train_full,
+                speed_factors=speed_factors,
+                id_column="id",
+                audio_column="audio",
+                transcription_column="transcription",
+                target_sr=TARGET_SR,
+            )
+            print("=" * 70 + "\n")
+        else:
+            print("⚠ Speed augmentation disabled - using original dataset size")

         # Create train/val split AFTER expansion
         if progress:
@@ -524,7 +583,12 @@ def run_whisper_training_progress(epochs: int, batch_size: int, learning_rate: f
             dataset_name=HF_DATASET_NAME,
             model_name=WHISPER_MODEL_NAME,
             split="train",
-            use_cache=True
+            use_cache=True,
+            specaug_enabled=specaug_enabled,
+            specaug_time_mask=specaug_time_mask,
+            specaug_freq_mask=specaug_freq_mask,
+            specaug_time_warp=specaug_time_warp,
+            specaug_warp_param=specaug_warp_param,
         )

         if progress:
@@ -536,7 +600,8 @@ def run_whisper_training_progress(epochs: int, batch_size: int, learning_rate: f
             dataset_name=HF_DATASET_NAME,
             model_name=WHISPER_MODEL_NAME,
             split="val",
-            use_cache=True
+            use_cache=True,
+            specaug_enabled=False,  # No augmentation for validation
         )

         # Training arguments
@@ -649,12 +714,14 @@ def run_whisper_training_progress(epochs: int, batch_size: int, learning_rate: f
         The model is now ready for inference!
         """

+        logger.close()
         return success_msg, final_metrics

     except Exception as e:
         import traceback
         error_msg = f"❌ Error during Whisper training: {str(e)}\n\n{traceback.format_exc()}"
         print(error_msg)
+        logger.close()
         if progress:
             progress(1.0, desc="Error!")
         return error_msg, None
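
For a quick sense of what the new min/max/count controls produce, the factor-generation logic from the hunk above can be pulled out as a standalone helper. The name generate_speed_factors is hypothetical and the body mirrors the diff as reconstructed here, so treat it as a sketch rather than the exact module code:

    def generate_speed_factors(fmin: float, fmax: float, count: int) -> list:
        """Evenly spaced speed factors between fmin and fmax, folding in 1.0 when count >= 3 and 1.0 is in range."""
        if count == 1:
            factors = [1.0]
        elif count == 2:
            factors = [fmin, fmax]
        else:
            factors = [fmin + (fmax - fmin) * i / (count - 1) for i in range(count)]
            if fmin <= 1.0 <= fmax:
                factors.append(1.0)
        return sorted(set(factors))  # Remove duplicates and sort

    print(generate_speed_factors(0.9, 1.1, 3))  # [0.9, 1.0, 1.1]
    print(generate_speed_factors(0.8, 1.2, 4))  # roughly [0.8, 0.93, 1.0, 1.07, 1.2], i.e. five variants

Note that with count >= 3 the number of variants can exceed the requested count by one whenever 1.0 falls strictly between two of the evenly spaced factors.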
ui/interface.py CHANGED
@@ -1,6 +1,7 @@
 """Gradio UI interface for Caribbean Voices OWSM platform."""
 import gradio as gr
 import time
+import os
 from pathlib import Path
 from datetime import datetime

@@ -12,6 +13,7 @@ from training.whisper_trainer import run_whisper_training_progress
 from models.inference import transcribe_audio, run_inference_owsm
 from models.loader import get_available_models
 from data.loader import load_data_from_hf_dataset
+from utils.logging import get_latest_log_file, get_all_log_files, get_log_directory


 def create_interface():
@@ -218,19 +220,84 @@ def create_interface():

         with gr.Row():
             with gr.Column():
+                gr.Markdown("#### Training Hyperparameters")
                 whisper_train_epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
                 whisper_train_batch_size = gr.Slider(1, 32, value=4, step=1, label="Batch Size")
                 whisper_train_lr = gr.Slider(1e-6, 1e-3, value=3e-5, step=1e-6, label="Learning Rate")
-                whisper_train_btn = gr.Button("Start Whisper Training", variant="primary")
+
+                gr.Markdown("#### Speed Augmentation")
+                gr.Markdown("Speed factors for dataset expansion (creates multiple versions of each sample)")
+                speed_aug_enabled = gr.Checkbox(value=True, label="Enable Speed Augmentation")
+                speed_factor_min = gr.Slider(0.8, 1.0, value=0.9, step=0.05, label="Min Speed Factor")
+                speed_factor_max = gr.Slider(1.0, 1.2, value=1.1, step=0.05, label="Max Speed Factor")
+                speed_factor_count = gr.Slider(2, 5, value=3, step=1, label="Number of Speed Variants")
+
+                gr.Markdown("#### SpecAugment Parameters")
+                gr.Markdown("Spectrogram augmentation settings (applied during training)")
+                specaug_enabled = gr.Checkbox(value=True, label="Enable SpecAugment")
+                specaug_time_mask = gr.Slider(0, 50, value=27, step=1, label="Time Mask Parameter")
+                specaug_freq_mask = gr.Slider(0, 20, value=10, step=1, label="Frequency Mask Parameter")
+                specaug_time_warp = gr.Checkbox(value=True, label="Enable Time Warping")
+                specaug_warp_param = gr.Slider(0, 80, value=40, step=5, label="Time Warp Parameter")
+
+                whisper_train_btn = gr.Button("Start Whisper Training", variant="primary", size="lg")

             with gr.Column():
                 whisper_train_output = gr.Markdown()
                 whisper_train_metrics = gr.JSON(label="Training Metrics")
+
+                gr.Markdown("#### Training Logs")
+                log_info = gr.Markdown(f"Log directory: `{get_log_directory()}`")
+                latest_log_file = gr.File(
+                    label="Download Latest Training Log",
+                    visible=False
+                )
+
+                def update_log_download():
+                    latest = get_latest_log_file("whisper_training")
+                    if latest and os.path.exists(latest):
+                        return gr.File(value=latest, visible=True)
+                    return gr.File(visible=False)
+
+                refresh_log_btn = gr.Button("🔄 Refresh Logs", variant="secondary", size="sm")
+                refresh_log_btn.click(
+                    fn=update_log_download,
+                    outputs=[latest_log_file]
+                )
+
+        def run_training_with_log_refresh(
+            epochs, batch_size, lr,
+            speed_aug_enabled, speed_factor_min, speed_factor_max, speed_factor_count,
+            specaug_enabled, specaug_time_mask, specaug_freq_mask, specaug_time_warp, specaug_warp_param,
+            progress=gr.Progress()
+        ):
+            """Run training and refresh log download after completion."""
+            result = run_whisper_training_progress(
+                epochs, batch_size, lr,
+                speed_aug_enabled, speed_factor_min, speed_factor_max, speed_factor_count,
+                specaug_enabled, specaug_time_mask, specaug_freq_mask, specaug_time_warp, specaug_warp_param,
+                progress
+            )
+            latest_log = update_log_download()
+            return result[0], result[1], latest_log

         whisper_train_btn.click(
-            fn=run_whisper_training_progress,
-            inputs=[whisper_train_epochs, whisper_train_batch_size, whisper_train_lr],
-            outputs=[whisper_train_output, whisper_train_metrics]
+            fn=run_training_with_log_refresh,
+            inputs=[
+                whisper_train_epochs,
+                whisper_train_batch_size,
+                whisper_train_lr,
+                speed_aug_enabled,
+                speed_factor_min,
+                speed_factor_max,
+                speed_factor_count,
+                specaug_enabled,
+                specaug_time_mask,
+                specaug_freq_mask,
+                specaug_time_warp,
+                specaug_warp_param,
+            ],
+            outputs=[whisper_train_output, whisper_train_metrics, latest_log_file]
         )

     # Tab 5: Inference
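
The log-download widgets above rely on a common Gradio pattern: the callback returns a component instance (here gr.File(...)) and Gradio applies it as an in-place update, which is how the download link stays hidden until a log file actually exists (this works in Gradio 4.x; older 3.x versions used gr.update / gr.File.update instead). A minimal, self-contained sketch of just that pattern, using a hypothetical path:

    import os
    import gradio as gr

    LOG_PATH = "/tmp/example.log"  # hypothetical file, for illustration only

    def refresh_download():
        # Reveal the File component only when the file exists on disk.
        if os.path.exists(LOG_PATH):
            return gr.File(value=LOG_PATH, visible=True)
        return gr.File(visible=False)

    with gr.Blocks() as demo:
        download = gr.File(label="Download log", visible=False)
        gr.Button("Refresh").click(fn=refresh_download, outputs=[download])

    # demo.launch()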
utils/logging.py ADDED
@@ -0,0 +1,116 @@
+"""
+Persistent logging utility for HuggingFace Spaces.
+Logs are written to files that persist across space restarts.
+"""
+import os
+import sys
+from datetime import datetime
+from pathlib import Path
+from typing import Optional
+
+# Try to use persistent storage if available
+# HF Spaces may have /tmp or other persistent locations
+PERSISTENT_LOG_DIR = None
+
+# Try common persistent locations
+for log_dir in [
+    "/tmp/logs",  # Common temp location (may persist)
+    "/persistent/logs",  # Some HF Spaces have this
+    os.path.join(os.path.expanduser("~"), ".cache", "caribbean-voices", "logs"),  # User cache
+]:
+    try:
+        Path(log_dir).mkdir(parents=True, exist_ok=True)
+        # Test write
+        test_file = os.path.join(log_dir, ".test_write")
+        with open(test_file, "w") as f:
+            f.write("test")
+        os.remove(test_file)
+        PERSISTENT_LOG_DIR = log_dir
+        break
+    except (PermissionError, OSError):
+        continue
+
+# Fallback to current directory if no persistent location found
+if PERSISTENT_LOG_DIR is None:
+    PERSISTENT_LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
+    Path(PERSISTENT_LOG_DIR).mkdir(parents=True, exist_ok=True)
+
+print(f"📝 Log directory: {PERSISTENT_LOG_DIR}")
+
+
+class TeeOutput:
+    """Tee output to both stdout and a file."""
+
+    def __init__(self, file_handle, original_stdout):
+        self.file = file_handle
+        self.stdout = original_stdout
+
+    def write(self, message):
+        self.stdout.write(message)
+        self.file.write(message)
+        self.file.flush()
+
+    def flush(self):
+        self.stdout.flush()
+        self.file.flush()
+
+
+class PersistentLogger:
+    """Logger that redirects stdout/stderr to both console and persistent log files."""
+
+    def __init__(self, log_name: str = "training"):
+        self.log_name = log_name
+        self.log_file = os.path.join(PERSISTENT_LOG_DIR, f"{log_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
+        self.log_handle = open(self.log_file, "a", buffering=1)  # Line buffered
+
+        # Save original stdout/stderr
+        self.original_stdout = sys.stdout
+        self.original_stderr = sys.stderr
+
+        # Create tee outputs
+        self.tee_stdout = TeeOutput(self.log_handle, self.original_stdout)
+        self.tee_stderr = TeeOutput(self.log_handle, self.original_stderr)
+
+        # Redirect stdout/stderr
+        sys.stdout = self.tee_stdout
+        sys.stderr = self.tee_stderr
+
+        print(f"📝 Logging to: {self.log_file}")
+
+    def close(self):
+        """Close the log file and restore original stdout/stderr."""
+        # Restore original stdout/stderr
+        sys.stdout = self.original_stdout
+        sys.stderr = self.original_stderr
+
+        # Close log file
+        if self.log_handle:
+            self.log_handle.close()
+            self.log_handle = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+
+def get_latest_log_file(log_name: str = "training") -> Optional[str]:
+    """Get the path to the latest log file for a given log name."""
+    log_files = list(Path(PERSISTENT_LOG_DIR).glob(f"{log_name}_*.log"))
+    if not log_files:
+        return None
+    # Sort by modification time, return most recent
+    latest = max(log_files, key=lambda p: p.stat().st_mtime)
+    return str(latest)
+
+
+def get_all_log_files(log_name: str = "training") -> list:
+    """Get all log files for a given log name, sorted by modification time (newest first)."""
+    log_files = list(Path(PERSISTENT_LOG_DIR).glob(f"{log_name}_*.log"))
+    return sorted(log_files, key=lambda p: p.stat().st_mtime, reverse=True)
+
+
+def get_log_directory() -> str:
+    """Get the log directory path."""
+    return PERSISTENT_LOG_DIR
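
Because PersistentLogger implements __enter__/__exit__, it can also be used as a context manager, which restores stdout/stderr even if the wrapped code raises. A short usage sketch (the log name is arbitrary):

    from utils.logging import PersistentLogger, get_latest_log_file

    with PersistentLogger("whisper_training"):
        # Everything printed in this block is mirrored to the console and to the log file.
        print("starting a training run...")

    # After the block exits, stdout/stderr are restored and the file handle is closed.
    print("latest log:", get_latest_log_file("whisper_training"))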