Fix audio decoding: use dataset.decode() to force decode all audio before expansion
Browse files - training/augmentation.py +43 -21
training/augmentation.py
CHANGED
|
@@ -307,45 +307,67 @@ def expand_dataset_with_speed_augmentation(
|
|
| 307 |
print(f" Original size: {len(dataset):,} samples")
|
| 308 |
print(f" Speed factors: {speed_factors}")
|
| 309 |
|
| 310 |
-
#
|
| 311 |
-
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
original_id = example.get(id_column, f"sample_{idx}")
|
|
|
|
|
|
|
| 314 |
audio_data = example.get(audio_column)
|
| 315 |
transcription = example.get(transcription_column, "")
|
| 316 |
|
| 317 |
-
# Handle audio data format
|
|
|
|
| 318 |
audio_array = None
|
| 319 |
audio_sr = target_sr
|
| 320 |
|
| 321 |
if audio_data is not None:
|
| 322 |
-
# Handle dict format (most common - audio is decoded when accessed via dataset[idx])
|
| 323 |
if isinstance(audio_data, dict):
|
| 324 |
audio_array = audio_data.get("array")
|
| 325 |
audio_sr = audio_data.get("sampling_rate", target_sr)
|
| 326 |
-
#
|
| 327 |
elif hasattr(audio_data, "array"):
|
| 328 |
audio_array = audio_data.array
|
| 329 |
audio_sr = getattr(audio_data, "sampling_rate", target_sr)
|
| 330 |
-
#
|
| 331 |
elif isinstance(audio_data, np.ndarray):
|
| 332 |
audio_array = audio_data
|
| 333 |
-
# If
|
| 334 |
elif hasattr(audio_data, "__class__") and "Decoder" in str(type(audio_data)):
|
| 335 |
-
# Try to get the decoded value - sometimes accessing via dict key helps
|
| 336 |
try:
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
audio_sr = getattr(audio_data, "sampling_rate", target_sr)
|
| 346 |
except Exception as e:
|
| 347 |
-
print(f"⚠ Warning: Failed to decode
|
|
|
|
|
|
|
|
|
|
| 348 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
if audio_array is None:
|
| 351 |
skipped_count += 1
|
|
@@ -407,7 +429,7 @@ def expand_dataset_with_speed_augmentation(
|
|
| 407 |
|
| 408 |
# Progress update every 1000 samples
|
| 409 |
if (idx + 1) % 1000 == 0:
|
| 410 |
-
print(f" Processed {idx + 1:,}/{len(dataset):,} samples...")
|
| 411 |
|
| 412 |
if len(expanded_examples) == 0:
|
| 413 |
raise ValueError(
|
|
@@ -423,7 +445,7 @@ def expand_dataset_with_speed_augmentation(
|
|
| 423 |
from datasets import Audio
|
| 424 |
expanded_dataset = expanded_dataset.cast_column(audio_column, Audio(sampling_rate=target_sr))
|
| 425 |
|
| 426 |
-
print(f"✓ Expanded dataset: {len(expanded_dataset):,} samples ({len(expanded_dataset) / len(dataset):.1f}x)")
|
| 427 |
if skipped_count > 0:
|
| 428 |
print(f"⚠ Skipped {skipped_count} samples during expansion")
|
| 429 |
|
|
|
|
| 307 |
print(f" Original size: {len(dataset):,} samples")
|
| 308 |
print(f" Speed factors: {speed_factors}")
|
| 309 |
|
| 310 |
+
# Force decode all audio files before iterating
|
| 311 |
+
# This ensures AudioDecoder objects are decoded to dict format with 'array' and 'sampling_rate'
|
| 312 |
+
print(" Decoding audio files...")
|
| 313 |
+
try:
|
| 314 |
+
# Use decode() to force decoding of all audio files
|
| 315 |
+
# This converts AudioDecoder objects to dict format
|
| 316 |
+
dataset_decoded = dataset.decode()
|
| 317 |
+
print(f" ✓ Audio decoding complete")
|
| 318 |
+
except Exception as e:
|
| 319 |
+
print(f" ⚠ Warning: Failed to decode dataset: {e}")
|
| 320 |
+
print(f" Continuing with undecoded dataset (may be slower)...")
|
| 321 |
+
dataset_decoded = dataset
|
| 322 |
+
|
| 323 |
+
# Use indexed access to iterate over decoded dataset
|
| 324 |
+
for idx in range(len(dataset_decoded)):
|
| 325 |
+
example = dataset_decoded[idx]
|
| 326 |
original_id = example.get(id_column, f"sample_{idx}")
|
| 327 |
+
|
| 328 |
+
# After decode(), audio should be a dict with 'array' and 'sampling_rate'
|
| 329 |
audio_data = example.get(audio_column)
|
| 330 |
transcription = example.get(transcription_column, "")
|
| 331 |
|
| 332 |
+
# Handle audio data format
|
| 333 |
+
# After decode(), audio_data should be a dict with 'array' and 'sampling_rate'
|
| 334 |
audio_array = None
|
| 335 |
audio_sr = target_sr
|
| 336 |
|
| 337 |
if audio_data is not None:
|
|
|
|
| 338 |
if isinstance(audio_data, dict):
|
| 339 |
audio_array = audio_data.get("array")
|
| 340 |
audio_sr = audio_data.get("sampling_rate", target_sr)
|
| 341 |
+
# Fallback: handle Audio object with .array attribute (in case decode() didn't work)
|
| 342 |
elif hasattr(audio_data, "array"):
|
| 343 |
audio_array = audio_data.array
|
| 344 |
audio_sr = getattr(audio_data, "sampling_rate", target_sr)
|
| 345 |
+
# Fallback: handle numpy array directly
|
| 346 |
elif isinstance(audio_data, np.ndarray):
|
| 347 |
audio_array = audio_data
|
| 348 |
+
# If still an AudioDecoder, try to decode it manually
|
| 349 |
elif hasattr(audio_data, "__class__") and "Decoder" in str(type(audio_data)):
|
|
|
|
| 350 |
try:
|
| 351 |
+
if hasattr(audio_data, "decode"):
|
| 352 |
+
decoded = audio_data.decode()
|
| 353 |
+
if isinstance(decoded, dict):
|
| 354 |
+
audio_array = decoded.get("array")
|
| 355 |
+
audio_sr = decoded.get("sampling_rate", target_sr)
|
| 356 |
+
elif hasattr(decoded, "array"):
|
| 357 |
+
audio_array = decoded.array
|
| 358 |
+
audio_sr = getattr(decoded, "sampling_rate", target_sr)
|
|
|
|
| 359 |
except Exception as e:
|
| 360 |
+
print(f"⚠ Warning: Failed to decode AudioDecoder for sample {original_id}: {e}, skipping...")
|
| 361 |
+
skipped_count += 1
|
| 362 |
+
if skipped_count <= 5:
|
| 363 |
+
print(f" Audio data type: {type(audio_data)}")
|
| 364 |
continue
|
| 365 |
+
else:
|
| 366 |
+
print(f"⚠ Warning: Unexpected audio format for sample {original_id}: {type(audio_data)}, skipping...")
|
| 367 |
+
skipped_count += 1
|
| 368 |
+
if skipped_count <= 5:
|
| 369 |
+
print(f" Audio data type: {type(audio_data)}")
|
| 370 |
+
continue
|
| 371 |
|
| 372 |
if audio_array is None:
|
| 373 |
skipped_count += 1
|
|
|
|
| 429 |
|
| 430 |
# Progress update every 1000 samples
|
| 431 |
if (idx + 1) % 1000 == 0:
|
| 432 |
+
print(f" Processed {idx + 1:,}/{len(dataset_decoded):,} samples...")
|
| 433 |
|
| 434 |
if len(expanded_examples) == 0:
|
| 435 |
raise ValueError(
|
|
|
|
| 445 |
from datasets import Audio
|
| 446 |
expanded_dataset = expanded_dataset.cast_column(audio_column, Audio(sampling_rate=target_sr))
|
| 447 |
|
| 448 |
+
print(f"✓ Expanded dataset: {len(expanded_dataset):,} samples ({len(expanded_dataset) / len(dataset_decoded):.1f}x)")
|
| 449 |
if skipped_count > 0:
|
| 450 |
print(f"⚠ Skipped {skipped_count} samples during expansion")
|
| 451 |
|