Spaces:
Running
on
Zero
Running
on
Zero
feat: trim 4s
Browse files
- wan/audio2video_multiID.py +2 -2
- wan/utils/infer_utils.py +8 -8
wan/audio2video_multiID.py
CHANGED
|
@@ -199,7 +199,7 @@ class WanAF2V:
|
|
| 199 |
audio_paths=None, # New: audio path list, supports multiple audio files
|
| 200 |
task_key=None,
|
| 201 |
mode="pad", # Audio processing mode: "pad" or "concat"
|
| 202 |
-
|
| 203 |
):
|
| 204 |
r"""
|
| 205 |
Generates video frames from input image and text prompt using diffusion process.
|
|
@@ -515,7 +515,7 @@ class WanAF2V:
|
|
| 515 |
half_dtype=self.half_dtype,
|
| 516 |
preprocess_audio=preprocess_audio,
|
| 517 |
resample_audio=resample_audio,
|
| 518 |
-
|
| 519 |
)
|
| 520 |
|
| 521 |
# Prepare audio_ref_features - new list mode
|
|
|
|
| 199 |
audio_paths=None, # New: audio path list, supports multiple audio files
|
| 200 |
task_key=None,
|
| 201 |
mode="pad", # Audio processing mode: "pad" or "concat"
|
| 202 |
+
trim_to_4s=False, # Fast mode: trim audio to 4 seconds
|
| 203 |
):
|
| 204 |
r"""
|
| 205 |
Generates video frames from input image and text prompt using diffusion process.
|
|
|
|
| 515 |
half_dtype=self.half_dtype,
|
| 516 |
preprocess_audio=preprocess_audio,
|
| 517 |
resample_audio=resample_audio,
|
| 518 |
+
trim_to_4s=trim_to_4s,
|
| 519 |
)
|
| 520 |
|
| 521 |
# Prepare audio_ref_features - new list mode
|
wan/utils/infer_utils.py
CHANGED
|
@@ -118,7 +118,7 @@ def process_audio_features(
|
|
| 118 |
half_dtype=None,
|
| 119 |
preprocess_audio=None,
|
| 120 |
resample_audio=None,
|
| 121 |
-
|
| 122 |
):
|
| 123 |
"""
|
| 124 |
Process audio files and extract audio features.
|
|
@@ -203,8 +203,8 @@ def process_audio_features(
|
|
| 203 |
total_length = sum(audio_lengths)
|
| 204 |
print(f"Total audio length in concat mode (from processed frames): {total_length} frames")
|
| 205 |
|
| 206 |
-
# Fast mode: trim to 4 seconds if
|
| 207 |
-
if
|
| 208 |
# 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
|
| 209 |
max_frames_4s = 97
|
| 210 |
if total_length > max_frames_4s:
|
|
@@ -281,7 +281,7 @@ def process_audio_features(
|
|
| 281 |
audio_feat_list.append(zero_audio_feat)
|
| 282 |
print(f"Audio {i} is missing, created zero features with shape: {zero_audio_feat.shape}")
|
| 283 |
else:
|
| 284 |
-
# Pad mode: keep existing logic, but apply
|
| 285 |
for i, audio_path in enumerate(audio_paths):
|
| 286 |
if audio_path and os.path.exists(audio_path):
|
| 287 |
print(f"Processing audio {i}: {audio_path}")
|
|
@@ -294,9 +294,9 @@ def process_audio_features(
|
|
| 294 |
with torch.no_grad():
|
| 295 |
print(f"wav2vec_model: {wav2vec_model}")
|
| 296 |
print(f"cache_dir:{cache_dir}")
|
| 297 |
-
# Fast mode: if
|
| 298 |
target_frames = F
|
| 299 |
-
if
|
| 300 |
# 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
|
| 301 |
max_frames_4s = 97
|
| 302 |
target_frames = min(F, max_frames_4s)
|
|
@@ -343,9 +343,9 @@ def process_audio_features(
|
|
| 343 |
target_resampled_audio_path,
|
| 344 |
)
|
| 345 |
with torch.no_grad():
|
| 346 |
-
# Fast mode: if
|
| 347 |
target_frames = F
|
| 348 |
-
if
|
| 349 |
# 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
|
| 350 |
max_frames_4s = 97
|
| 351 |
target_frames = min(F, max_frames_4s)
|
|
|
|
| 118 |
half_dtype=None,
|
| 119 |
preprocess_audio=None,
|
| 120 |
resample_audio=None,
|
| 121 |
+
trim_to_4s=False, # Fast mode: trim audio to 4 seconds
|
| 122 |
):
|
| 123 |
"""
|
| 124 |
Process audio files and extract audio features.
|
|
|
|
| 203 |
total_length = sum(audio_lengths)
|
| 204 |
print(f"Total audio length in concat mode (from processed frames): {total_length} frames")
|
| 205 |
|
| 206 |
+
# Fast mode: trim to 4 seconds if trim_to_4s is True
|
| 207 |
+
if trim_to_4s:
|
| 208 |
# 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
|
| 209 |
max_frames_4s = 97
|
| 210 |
if total_length > max_frames_4s:
|
|
|
|
| 281 |
audio_feat_list.append(zero_audio_feat)
|
| 282 |
print(f"Audio {i} is missing, created zero features with shape: {zero_audio_feat.shape}")
|
| 283 |
else:
|
| 284 |
+
# Pad mode: keep existing logic, but apply trim_to_4s if needed
|
| 285 |
for i, audio_path in enumerate(audio_paths):
|
| 286 |
if audio_path and os.path.exists(audio_path):
|
| 287 |
print(f"Processing audio {i}: {audio_path}")
|
|
|
|
| 294 |
with torch.no_grad():
|
| 295 |
print(f"wav2vec_model: {wav2vec_model}")
|
| 296 |
print(f"cache_dir:{cache_dir}")
|
| 297 |
+
# Fast mode: if trim_to_4s, limit to 4 seconds
|
| 298 |
target_frames = F
|
| 299 |
+
if trim_to_4s:
|
| 300 |
# 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
|
| 301 |
max_frames_4s = 97
|
| 302 |
target_frames = min(F, max_frames_4s)
|
|
|
|
| 343 |
target_resampled_audio_path,
|
| 344 |
)
|
| 345 |
with torch.no_grad():
|
| 346 |
+
# Fast mode: if trim_to_4s, limit to 4 seconds
|
| 347 |
target_frames = F
|
| 348 |
+
if trim_to_4s:
|
| 349 |
# 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
|
| 350 |
max_frames_4s = 97
|
| 351 |
target_frames = min(F, max_frames_4s)
|