shaun3141 committed on
Commit c41a0cd · 1 Parent(s): f931329

Fix AudioDecoder handling: use get_all_samples() method to extract audio data from AudioDecoder objects

Files changed (1)
  1. training/augmentation.py +60 -37
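
For context on the fix: the commit's comments describe the AudioDecoder objects (the torchcodec-backed decoders newer `datasets` releases hand back for audio columns) as exposing a `get_all_samples()` method whose result carries a `.data` tensor and a `.sample_rate`. A minimal sketch of that extraction path, mirroring the logic added in this commit — `decoder_to_numpy` and `fallback_sr` are illustrative names, not part of the repository:

import numpy as np

def decoder_to_numpy(decoder, fallback_sr=16000):
    """Turn an AudioDecoder-like object into (mono numpy array, sampling rate).

    Assumes decoder.get_all_samples() returns an object with .data
    (a torch tensor shaped [channels, num_samples]) and .sample_rate,
    as the commit's inline comments state.
    """
    samples = decoder.get_all_samples()
    data = samples.data
    # Move to CPU and convert to numpy if it is a torch tensor, else coerce directly
    array = data.cpu().numpy() if hasattr(data, "cpu") else np.asarray(data)
    # Collapse channels: keep mono as-is, average stereo down to one channel
    if array.ndim > 1:
        array = array[0] if array.shape[0] == 1 else array.mean(axis=0)
    return array, getattr(samples, "sample_rate", fallback_sr)
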
training/augmentation.py CHANGED
@@ -307,66 +307,89 @@ def expand_dataset_with_speed_augmentation(
     print(f" Original size: {len(dataset):,} samples")
     print(f" Speed factors: {speed_factors}")
 
-    # Force decode all audio files before iterating
-    # This ensures AudioDecoder objects are decoded to dict format with 'array' and 'sampling_rate'
-    print(" Decoding audio files...")
-    try:
-        # Use decode() to force decoding of all audio files
-        # This converts AudioDecoder objects to dict format
-        dataset_decoded = dataset.decode()
-        print(f" ✓ Audio decoding complete")
-    except Exception as e:
-        print(f" ⚠ Warning: Failed to decode dataset: {e}")
-        print(f" Continuing with undecoded dataset (may be slower)...")
-        dataset_decoded = dataset
-
-    # Use indexed access to iterate over decoded dataset
-    for idx in range(len(dataset_decoded)):
-        example = dataset_decoded[idx]
+    # Access audio column directly using bracket notation to trigger decoding
+    # HuggingFace datasets automatically decode AudioDecoder when accessed via bracket notation
+    print(" Processing audio files (decoding on-the-fly)...")
+
+    # Use indexed access to iterate over dataset
+    for idx in range(len(dataset)):
+        example = dataset[idx]
         original_id = example.get(id_column, f"sample_{idx}")
 
-        # After decode(), audio should be a dict with 'array' and 'sampling_rate'
-        audio_data = example.get(audio_column)
+        # Access audio column using bracket notation - this should trigger automatic decoding
+        # Try bracket access first (triggers decoding), fallback to .get() if needed
+        try:
+            audio_data = example[audio_column]  # Bracket access triggers decoding
+        except (KeyError, TypeError):
+            try:
+                audio_data = example.get(audio_column)
+            except:
+                audio_data = None
+
         transcription = example.get(transcription_column, "")
 
         # Handle audio data format
-        # After decode(), audio_data should be a dict with 'array' and 'sampling_rate'
+        # AudioDecoder objects need to use get_all_samples() to extract audio data
        audio_array = None
         audio_sr = target_sr
 
         if audio_data is not None:
+            # Case 1: Already decoded dict format (ideal case)
             if isinstance(audio_data, dict):
                 audio_array = audio_data.get("array")
                 audio_sr = audio_data.get("sampling_rate", target_sr)
-            # Fallback: handle Audio object with .array attribute (in case decode() didn't work)
-            elif hasattr(audio_data, "array"):
-                audio_array = audio_data.array
-                audio_sr = getattr(audio_data, "sampling_rate", target_sr)
-            # Fallback: handle numpy array directly
-            elif isinstance(audio_data, np.ndarray):
-                audio_array = audio_data
-            # If still an AudioDecoder, try to decode it manually
+            # Case 2: AudioDecoder object - use get_all_samples() method
             elif hasattr(audio_data, "__class__") and "Decoder" in str(type(audio_data)):
                 try:
-                    if hasattr(audio_data, "decode"):
-                        decoded = audio_data.decode()
-                        if isinstance(decoded, dict):
-                            audio_array = decoded.get("array")
-                            audio_sr = decoded.get("sampling_rate", target_sr)
-                        elif hasattr(decoded, "array"):
-                            audio_array = decoded.array
-                            audio_sr = getattr(decoded, "sampling_rate", target_sr)
+                    # AudioDecoder has get_all_samples() method that returns AudioSamples
+                    if hasattr(audio_data, "get_all_samples"):
+                        audio_samples = audio_data.get_all_samples()
+                        # AudioSamples has .data (PyTorch tensor) and .sample_rate
+                        if hasattr(audio_samples, "data"):
+                            # Convert PyTorch tensor to numpy array
+                            if hasattr(audio_samples.data, "numpy"):
+                                audio_array = audio_samples.data.numpy()
+                            elif hasattr(audio_samples.data, "cpu"):
+                                # If it's on GPU, move to CPU first
+                                audio_array = audio_samples.data.cpu().numpy()
+                            else:
+                                # Try to convert directly
+                                audio_array = np.array(audio_samples.data)
+
+                            # Handle multi-channel audio (take first channel if stereo)
+                            if audio_array.ndim > 1:
+                                audio_array = audio_array[0] if audio_array.shape[0] == 1 else audio_array.mean(axis=0)
+
+                            # Get sampling rate
+                            if hasattr(audio_samples, "sample_rate"):
+                                audio_sr = audio_samples.sample_rate
+                            elif hasattr(audio_samples, "sampling_rate"):
+                                audio_sr = audio_samples.sampling_rate
+                        else:
+                            raise ValueError("AudioSamples object doesn't have 'data' attribute")
+                    else:
+                        raise ValueError("AudioDecoder doesn't have 'get_all_samples' method")
                 except Exception as e:
                     print(f"⚠ Warning: Failed to decode AudioDecoder for sample {original_id}: {e}, skipping...")
                     skipped_count += 1
                     if skipped_count <= 5:
                         print(f" Audio data type: {type(audio_data)}")
+                        import traceback
+                        print(f" Error details: {traceback.format_exc()}")
                     continue
+            # Case 3: Audio object with .array attribute
+            elif hasattr(audio_data, "array"):
+                audio_array = audio_data.array
+                audio_sr = getattr(audio_data, "sampling_rate", target_sr)
+            # Case 4: Already a numpy array
+            elif isinstance(audio_data, np.ndarray):
+                audio_array = audio_data
             else:
                 print(f"⚠ Warning: Unexpected audio format for sample {original_id}: {type(audio_data)}, skipping...")
                 skipped_count += 1
                 if skipped_count <= 5:
                     print(f" Audio data type: {type(audio_data)}")
+                    print(f" Available attributes: {dir(audio_data)[:10]}")
                 continue
 
         if audio_array is None:
@@ -429,7 +452,7 @@ def expand_dataset_with_speed_augmentation(
 
         # Progress update every 1000 samples
         if (idx + 1) % 1000 == 0:
-            print(f" Processed {idx + 1:,}/{len(dataset_decoded):,} samples...")
+            print(f" Processed {idx + 1:,}/{len(dataset):,} samples...")
 
     if len(expanded_examples) == 0:
         raise ValueError(
@@ -445,7 +468,7 @@ def expand_dataset_with_speed_augmentation(
     from datasets import Audio
     expanded_dataset = expanded_dataset.cast_column(audio_column, Audio(sampling_rate=target_sr))
 
-    print(f"✓ Expanded dataset: {len(expanded_dataset):,} samples ({len(expanded_dataset) / len(dataset_decoded):.1f}x)")
+    print(f"✓ Expanded dataset: {len(expanded_dataset):,} samples ({len(expanded_dataset) / len(dataset):.1f}x)")
     if skipped_count > 0:
         print(f"⚠ Skipped {skipped_count} samples during expansion")
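
If you want to confirm which branch of the new fallback chain a given `datasets` version actually exercises (already-decoded dict, decoder object, or something else), a small check like the one below works; the dataset id and the "audio" column name are placeholders, not values from this repository:

from datasets import load_dataset

ds = load_dataset("your-org/your-audio-dataset", split="train")  # placeholder id
first_audio = ds[0]["audio"]  # bracket access, same as the new loop
if isinstance(first_audio, dict):
    print("decoded dict with keys:", list(first_audio.keys()))
elif "Decoder" in type(first_audio).__name__:
    print("decoder object:", type(first_audio).__name__)
else:
    print("other format:", type(first_audio))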