shaun3141 committed on
Commit
f931329
·
1 Parent(s): 556f0f5

Fix audio decoding: use dataset.decode() to force decode all audio before expansion

Browse files
Files changed (1) hide show
  1. training/augmentation.py +43 -21
training/augmentation.py CHANGED
@@ -307,45 +307,67 @@ def expand_dataset_with_speed_augmentation(
307
  print(f" Original size: {len(dataset):,} samples")
308
  print(f" Speed factors: {speed_factors}")
309
 
310
- # Use indexed access to ensure audio decoding happens
311
- for idx in range(len(dataset)):
312
- example = dataset[idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  original_id = example.get(id_column, f"sample_{idx}")
 
 
314
  audio_data = example.get(audio_column)
315
  transcription = example.get(transcription_column, "")
316
 
317
- # Handle audio data format - decode if needed
 
318
  audio_array = None
319
  audio_sr = target_sr
320
 
321
  if audio_data is not None:
322
- # Handle dict format (most common - audio is decoded when accessed via dataset[idx])
323
  if isinstance(audio_data, dict):
324
  audio_array = audio_data.get("array")
325
  audio_sr = audio_data.get("sampling_rate", target_sr)
326
- # Handle Audio object with .array attribute
327
  elif hasattr(audio_data, "array"):
328
  audio_array = audio_data.array
329
  audio_sr = getattr(audio_data, "sampling_rate", target_sr)
330
- # Handle if it's already a numpy array
331
  elif isinstance(audio_data, np.ndarray):
332
  audio_array = audio_data
333
- # If it's still an AudioDecoder, try to access it properly
334
  elif hasattr(audio_data, "__class__") and "Decoder" in str(type(audio_data)):
335
- # Try to get the decoded value - sometimes accessing via dict key helps
336
  try:
337
- # Force decoding by accessing the audio column again with proper indexing
338
- decoded_example = dataset[idx]
339
- audio_data = decoded_example.get(audio_column)
340
- if isinstance(audio_data, dict):
341
- audio_array = audio_data.get("array")
342
- audio_sr = audio_data.get("sampling_rate", target_sr)
343
- elif hasattr(audio_data, "array"):
344
- audio_array = audio_data.array
345
- audio_sr = getattr(audio_data, "sampling_rate", target_sr)
346
  except Exception as e:
347
- print(f"⚠ Warning: Failed to decode audio for sample {original_id}: {e}, skipping...")
 
 
 
348
  continue
 
 
 
 
 
 
349
 
350
  if audio_array is None:
351
  skipped_count += 1
@@ -407,7 +429,7 @@ def expand_dataset_with_speed_augmentation(
407
 
408
  # Progress update every 1000 samples
409
  if (idx + 1) % 1000 == 0:
410
- print(f" Processed {idx + 1:,}/{len(dataset):,} samples...")
411
 
412
  if len(expanded_examples) == 0:
413
  raise ValueError(
@@ -423,7 +445,7 @@ def expand_dataset_with_speed_augmentation(
423
  from datasets import Audio
424
  expanded_dataset = expanded_dataset.cast_column(audio_column, Audio(sampling_rate=target_sr))
425
 
426
- print(f"✓ Expanded dataset: {len(expanded_dataset):,} samples ({len(expanded_dataset) / len(dataset):.1f}x)")
427
  if skipped_count > 0:
428
  print(f"⚠ Skipped {skipped_count} samples during expansion")
429
 
 
307
  print(f" Original size: {len(dataset):,} samples")
308
  print(f" Speed factors: {speed_factors}")
309
 
310
+ # Force decode all audio files before iterating
311
+ # This ensures AudioDecoder objects are decoded to dict format with 'array' and 'sampling_rate'
312
+ print(" Decoding audio files...")
313
+ try:
314
+ # Use decode() to force decoding of all audio files
315
+ # This converts AudioDecoder objects to dict format
316
+ dataset_decoded = dataset.decode()
317
+ print(f" ✓ Audio decoding complete")
318
+ except Exception as e:
319
+ print(f" ⚠ Warning: Failed to decode dataset: {e}")
320
+ print(f" Continuing with undecoded dataset (may be slower)...")
321
+ dataset_decoded = dataset
322
+
323
+ # Use indexed access to iterate over decoded dataset
324
+ for idx in range(len(dataset_decoded)):
325
+ example = dataset_decoded[idx]
326
  original_id = example.get(id_column, f"sample_{idx}")
327
+
328
+ # After decode(), audio should be a dict with 'array' and 'sampling_rate'
329
  audio_data = example.get(audio_column)
330
  transcription = example.get(transcription_column, "")
331
 
332
+ # Handle audio data format
333
+ # After decode(), audio_data should be a dict with 'array' and 'sampling_rate'
334
  audio_array = None
335
  audio_sr = target_sr
336
 
337
  if audio_data is not None:
 
338
  if isinstance(audio_data, dict):
339
  audio_array = audio_data.get("array")
340
  audio_sr = audio_data.get("sampling_rate", target_sr)
341
+ # Fallback: handle Audio object with .array attribute (in case decode() didn't work)
342
  elif hasattr(audio_data, "array"):
343
  audio_array = audio_data.array
344
  audio_sr = getattr(audio_data, "sampling_rate", target_sr)
345
+ # Fallback: handle numpy array directly
346
  elif isinstance(audio_data, np.ndarray):
347
  audio_array = audio_data
348
+ # If still an AudioDecoder, try to decode it manually
349
  elif hasattr(audio_data, "__class__") and "Decoder" in str(type(audio_data)):
 
350
  try:
351
+ if hasattr(audio_data, "decode"):
352
+ decoded = audio_data.decode()
353
+ if isinstance(decoded, dict):
354
+ audio_array = decoded.get("array")
355
+ audio_sr = decoded.get("sampling_rate", target_sr)
356
+ elif hasattr(decoded, "array"):
357
+ audio_array = decoded.array
358
+ audio_sr = getattr(decoded, "sampling_rate", target_sr)
 
359
  except Exception as e:
360
+ print(f"⚠ Warning: Failed to decode AudioDecoder for sample {original_id}: {e}, skipping...")
361
+ skipped_count += 1
362
+ if skipped_count <= 5:
363
+ print(f" Audio data type: {type(audio_data)}")
364
  continue
365
+ else:
366
+ print(f"⚠ Warning: Unexpected audio format for sample {original_id}: {type(audio_data)}, skipping...")
367
+ skipped_count += 1
368
+ if skipped_count <= 5:
369
+ print(f" Audio data type: {type(audio_data)}")
370
+ continue
371
 
372
  if audio_array is None:
373
  skipped_count += 1
 
429
 
430
  # Progress update every 1000 samples
431
  if (idx + 1) % 1000 == 0:
432
+ print(f" Processed {idx + 1:,}/{len(dataset_decoded):,} samples...")
433
 
434
  if len(expanded_examples) == 0:
435
  raise ValueError(
 
445
  from datasets import Audio
446
  expanded_dataset = expanded_dataset.cast_column(audio_column, Audio(sampling_rate=target_sr))
447
 
448
+ print(f"✓ Expanded dataset: {len(expanded_dataset):,} samples ({len(expanded_dataset) / len(dataset_decoded):.1f}x)")
449
  if skipped_count > 0:
450
  print(f"⚠ Skipped {skipped_count} samples during expansion")
451