Fix AudioDecoder handling: use get_all_samples() method to extract audio data from AudioDecoder objects
Browse files- training/augmentation.py +60 -37
training/augmentation.py
CHANGED
|
@@ -307,66 +307,89 @@ def expand_dataset_with_speed_augmentation(
|
|
| 307 |
print(f" Original size: {len(dataset):,} samples")
|
| 308 |
print(f" Speed factors: {speed_factors}")
|
| 309 |
|
| 310 |
-
#
|
| 311 |
-
#
|
| 312 |
-
print("
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
print(f" ✓ Audio decoding complete")
|
| 318 |
-
except Exception as e:
|
| 319 |
-
print(f" ⚠ Warning: Failed to decode dataset: {e}")
|
| 320 |
-
print(f" Continuing with undecoded dataset (may be slower)...")
|
| 321 |
-
dataset_decoded = dataset
|
| 322 |
-
|
| 323 |
-
# Use indexed access to iterate over decoded dataset
|
| 324 |
-
for idx in range(len(dataset_decoded)):
|
| 325 |
-
example = dataset_decoded[idx]
|
| 326 |
original_id = example.get(id_column, f"sample_{idx}")
|
| 327 |
|
| 328 |
-
#
|
| 329 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
transcription = example.get(transcription_column, "")
|
| 331 |
|
| 332 |
# Handle audio data format
|
| 333 |
-
#
|
| 334 |
audio_array = None
|
| 335 |
audio_sr = target_sr
|
| 336 |
|
| 337 |
if audio_data is not None:
|
|
|
|
| 338 |
if isinstance(audio_data, dict):
|
| 339 |
audio_array = audio_data.get("array")
|
| 340 |
audio_sr = audio_data.get("sampling_rate", target_sr)
|
| 341 |
-
#
|
| 342 |
-
elif hasattr(audio_data, "array"):
|
| 343 |
-
audio_array = audio_data.array
|
| 344 |
-
audio_sr = getattr(audio_data, "sampling_rate", target_sr)
|
| 345 |
-
# Fallback: handle numpy array directly
|
| 346 |
-
elif isinstance(audio_data, np.ndarray):
|
| 347 |
-
audio_array = audio_data
|
| 348 |
-
# If still an AudioDecoder, try to decode it manually
|
| 349 |
elif hasattr(audio_data, "__class__") and "Decoder" in str(type(audio_data)):
|
| 350 |
try:
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
except Exception as e:
|
| 360 |
print(f"⚠ Warning: Failed to decode AudioDecoder for sample {original_id}: {e}, skipping...")
|
| 361 |
skipped_count += 1
|
| 362 |
if skipped_count <= 5:
|
| 363 |
print(f" Audio data type: {type(audio_data)}")
|
|
|
|
|
|
|
| 364 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
else:
|
| 366 |
print(f"⚠ Warning: Unexpected audio format for sample {original_id}: {type(audio_data)}, skipping...")
|
| 367 |
skipped_count += 1
|
| 368 |
if skipped_count <= 5:
|
| 369 |
print(f" Audio data type: {type(audio_data)}")
|
|
|
|
| 370 |
continue
|
| 371 |
|
| 372 |
if audio_array is None:
|
|
@@ -429,7 +452,7 @@ def expand_dataset_with_speed_augmentation(
|
|
| 429 |
|
| 430 |
# Progress update every 1000 samples
|
| 431 |
if (idx + 1) % 1000 == 0:
|
| 432 |
-
print(f" Processed {idx + 1:,}/{len(
|
| 433 |
|
| 434 |
if len(expanded_examples) == 0:
|
| 435 |
raise ValueError(
|
|
@@ -445,7 +468,7 @@ def expand_dataset_with_speed_augmentation(
|
|
| 445 |
from datasets import Audio
|
| 446 |
expanded_dataset = expanded_dataset.cast_column(audio_column, Audio(sampling_rate=target_sr))
|
| 447 |
|
| 448 |
-
print(f"✓ Expanded dataset: {len(expanded_dataset):,} samples ({len(expanded_dataset) / len(
|
| 449 |
if skipped_count > 0:
|
| 450 |
print(f"⚠ Skipped {skipped_count} samples during expansion")
|
| 451 |
|
|
|
|
| 307 |
print(f" Original size: {len(dataset):,} samples")
|
| 308 |
print(f" Speed factors: {speed_factors}")
|
| 309 |
|
| 310 |
+
# Access audio column directly using bracket notation to trigger decoding
|
| 311 |
+
# HuggingFace datasets automatically decode AudioDecoder when accessed via bracket notation
|
| 312 |
+
print(" Processing audio files (decoding on-the-fly)...")
|
| 313 |
+
|
| 314 |
+
# Use indexed access to iterate over dataset
|
| 315 |
+
for idx in range(len(dataset)):
|
| 316 |
+
example = dataset[idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
original_id = example.get(id_column, f"sample_{idx}")
|
| 318 |
|
| 319 |
+
# Access audio column using bracket notation - this should trigger automatic decoding
|
| 320 |
+
# Try bracket access first (triggers decoding), fallback to .get() if needed
|
| 321 |
+
try:
|
| 322 |
+
audio_data = example[audio_column] # Bracket access triggers decoding
|
| 323 |
+
except (KeyError, TypeError):
|
| 324 |
+
try:
|
| 325 |
+
audio_data = example.get(audio_column)
|
| 326 |
+
except:
|
| 327 |
+
audio_data = None
|
| 328 |
+
|
| 329 |
transcription = example.get(transcription_column, "")
|
| 330 |
|
| 331 |
# Handle audio data format
|
| 332 |
+
# AudioDecoder objects need to use get_all_samples() to extract audio data
|
| 333 |
audio_array = None
|
| 334 |
audio_sr = target_sr
|
| 335 |
|
| 336 |
if audio_data is not None:
|
| 337 |
+
# Case 1: Already decoded dict format (ideal case)
|
| 338 |
if isinstance(audio_data, dict):
|
| 339 |
audio_array = audio_data.get("array")
|
| 340 |
audio_sr = audio_data.get("sampling_rate", target_sr)
|
| 341 |
+
# Case 2: AudioDecoder object - use get_all_samples() method
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
elif hasattr(audio_data, "__class__") and "Decoder" in str(type(audio_data)):
|
| 343 |
try:
|
| 344 |
+
# AudioDecoder has get_all_samples() method that returns AudioSamples
|
| 345 |
+
if hasattr(audio_data, "get_all_samples"):
|
| 346 |
+
audio_samples = audio_data.get_all_samples()
|
| 347 |
+
# AudioSamples has .data (PyTorch tensor) and .sample_rate
|
| 348 |
+
if hasattr(audio_samples, "data"):
|
| 349 |
+
# Convert PyTorch tensor to numpy array
|
| 350 |
+
if hasattr(audio_samples.data, "numpy"):
|
| 351 |
+
audio_array = audio_samples.data.numpy()
|
| 352 |
+
elif hasattr(audio_samples.data, "cpu"):
|
| 353 |
+
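# Fallback for objects that lack .numpy() but expose .cpu(); NOTE(review): a torch tensor has both, so a GPU tensor would take the .numpy() branch above instead - confirm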
|
| 354 |
+
audio_array = audio_samples.data.cpu().numpy()
|
| 355 |
+
else:
|
| 356 |
+
# Try to convert directly
|
| 357 |
+
audio_array = np.array(audio_samples.data)
|
| 358 |
+
|
| 359 |
+
# Collapse to 1-D: take the single row when axis 0 has length 1, otherwise average over axis 0 (assumed to be the channel axis)
|
| 360 |
+
if audio_array.ndim > 1:
|
| 361 |
+
audio_array = audio_array[0] if audio_array.shape[0] == 1 else audio_array.mean(axis=0)
|
| 362 |
+
|
| 363 |
+
# Get sampling rate
|
| 364 |
+
if hasattr(audio_samples, "sample_rate"):
|
| 365 |
+
audio_sr = audio_samples.sample_rate
|
| 366 |
+
elif hasattr(audio_samples, "sampling_rate"):
|
| 367 |
+
audio_sr = audio_samples.sampling_rate
|
| 368 |
+
else:
|
| 369 |
+
raise ValueError("AudioSamples object doesn't have 'data' attribute")
|
| 370 |
+
else:
|
| 371 |
+
raise ValueError("AudioDecoder doesn't have 'get_all_samples' method")
|
| 372 |
except Exception as e:
|
| 373 |
print(f"⚠ Warning: Failed to decode AudioDecoder for sample {original_id}: {e}, skipping...")
|
| 374 |
skipped_count += 1
|
| 375 |
if skipped_count <= 5:
|
| 376 |
print(f" Audio data type: {type(audio_data)}")
|
| 377 |
+
import traceback
|
| 378 |
+
print(f" Error details: {traceback.format_exc()}")
|
| 379 |
continue
|
| 380 |
+
# Case 3: Audio object with .array attribute
|
| 381 |
+
elif hasattr(audio_data, "array"):
|
| 382 |
+
audio_array = audio_data.array
|
| 383 |
+
audio_sr = getattr(audio_data, "sampling_rate", target_sr)
|
| 384 |
+
# Case 4: Already a numpy array
|
| 385 |
+
elif isinstance(audio_data, np.ndarray):
|
| 386 |
+
audio_array = audio_data
|
| 387 |
else:
|
| 388 |
print(f"⚠ Warning: Unexpected audio format for sample {original_id}: {type(audio_data)}, skipping...")
|
| 389 |
skipped_count += 1
|
| 390 |
if skipped_count <= 5:
|
| 391 |
print(f" Audio data type: {type(audio_data)}")
|
| 392 |
+
print(f" Available attributes: {dir(audio_data)[:10]}")
|
| 393 |
continue
|
| 394 |
|
| 395 |
if audio_array is None:
|
|
|
|
| 452 |
|
| 453 |
# Progress update every 1000 samples
|
| 454 |
if (idx + 1) % 1000 == 0:
|
| 455 |
+
print(f" Processed {idx + 1:,}/{len(dataset):,} samples...")
|
| 456 |
|
| 457 |
if len(expanded_examples) == 0:
|
| 458 |
raise ValueError(
|
|
|
|
| 468 |
from datasets import Audio
|
| 469 |
expanded_dataset = expanded_dataset.cast_column(audio_column, Audio(sampling_rate=target_sr))
|
| 470 |
|
| 471 |
+
print(f"✓ Expanded dataset: {len(expanded_dataset):,} samples ({len(expanded_dataset) / len(dataset):.1f}x)")
|
| 472 |
if skipped_count > 0:
|
| 473 |
print(f"⚠ Skipped {skipped_count} samples during expansion")
|
| 474 |
|