C4G-HKUST committed on
Commit
6c41e4a
·
1 Parent(s): 0c6b95b

feat: trim 4s

Browse files
wan/audio2video_multiID.py CHANGED
@@ -199,7 +199,7 @@ class WanAF2V:
199
  audio_paths=None, # New: audio path list, supports multiple audio files
200
  task_key=None,
201
  mode="pad", # Audio processing mode: "pad" or "concat"
202
- trim_to_6s=False, # Fast mode: trim audio to 4 seconds
203
  ):
204
  r"""
205
  Generates video frames from input image and text prompt using diffusion process.
@@ -515,7 +515,7 @@ class WanAF2V:
515
  half_dtype=self.half_dtype,
516
  preprocess_audio=preprocess_audio,
517
  resample_audio=resample_audio,
518
- trim_to_6s=trim_to_6s,
519
  )
520
 
521
  # Prepare audio_ref_features - new list mode
 
199
  audio_paths=None, # New: audio path list, supports multiple audio files
200
  task_key=None,
201
  mode="pad", # Audio processing mode: "pad" or "concat"
202
+ trim_to_4s=False, # Fast mode: trim audio to 4 seconds
203
  ):
204
  r"""
205
  Generates video frames from input image and text prompt using diffusion process.
 
515
  half_dtype=self.half_dtype,
516
  preprocess_audio=preprocess_audio,
517
  resample_audio=resample_audio,
518
+ trim_to_4s=trim_to_4s,
519
  )
520
 
521
  # Prepare audio_ref_features - new list mode
wan/utils/infer_utils.py CHANGED
@@ -118,7 +118,7 @@ def process_audio_features(
118
  half_dtype=None,
119
  preprocess_audio=None,
120
  resample_audio=None,
121
- trim_to_6s=False, # Fast mode: trim audio to 4 seconds
122
  ):
123
  """
124
  Process audio files and extract audio features.
@@ -203,8 +203,8 @@ def process_audio_features(
203
  total_length = sum(audio_lengths)
204
  print(f"Total audio length in concat mode (from processed frames): {total_length} frames")
205
 
206
- # Fast mode: trim to 4 seconds if trim_to_6s is True
207
- if trim_to_6s:
208
  # 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
209
  max_frames_4s = 97
210
  if total_length > max_frames_4s:
@@ -281,7 +281,7 @@ def process_audio_features(
281
  audio_feat_list.append(zero_audio_feat)
282
  print(f"Audio {i} is missing, created zero features with shape: {zero_audio_feat.shape}")
283
  else:
284
- # Pad mode: keep existing logic, but apply trim_to_6s if needed
285
  for i, audio_path in enumerate(audio_paths):
286
  if audio_path and os.path.exists(audio_path):
287
  print(f"Processing audio {i}: {audio_path}")
@@ -294,9 +294,9 @@ def process_audio_features(
294
  with torch.no_grad():
295
  print(f"wav2vec_model: {wav2vec_model}")
296
  print(f"cache_dir:{cache_dir}")
297
- # Fast mode: if trim_to_6s, limit to 4 seconds
298
  target_frames = F
299
- if trim_to_6s:
300
  # 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
301
  max_frames_4s = 97
302
  target_frames = min(F, max_frames_4s)
@@ -343,9 +343,9 @@ def process_audio_features(
343
  target_resampled_audio_path,
344
  )
345
  with torch.no_grad():
346
- # Fast mode: if trim_to_6s, limit to 4 seconds
347
  target_frames = F
348
- if trim_to_6s:
349
  # 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
350
  max_frames_4s = 97
351
  target_frames = min(F, max_frames_4s)
 
118
  half_dtype=None,
119
  preprocess_audio=None,
120
  resample_audio=None,
121
+ trim_to_4s=False, # Fast mode: trim audio to 4 seconds
122
  ):
123
  """
124
  Process audio files and extract audio features.
 
203
  total_length = sum(audio_lengths)
204
  print(f"Total audio length in concat mode (from processed frames): {total_length} frames")
205
 
206
+ # Fast mode: trim to 4 seconds if trim_to_4s is True
207
+ if trim_to_4s:
208
  # 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
209
  max_frames_4s = 97
210
  if total_length > max_frames_4s:
 
281
  audio_feat_list.append(zero_audio_feat)
282
  print(f"Audio {i} is missing, created zero features with shape: {zero_audio_feat.shape}")
283
  else:
284
+ # Pad mode: keep existing logic, but apply trim_to_4s if needed
285
  for i, audio_path in enumerate(audio_paths):
286
  if audio_path and os.path.exists(audio_path):
287
  print(f"Processing audio {i}: {audio_path}")
 
294
  with torch.no_grad():
295
  print(f"wav2vec_model: {wav2vec_model}")
296
  print(f"cache_dir:{cache_dir}")
297
+ # Fast mode: if trim_to_4s, limit to 4 seconds
298
  target_frames = F
299
+ if trim_to_4s:
300
  # 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
301
  max_frames_4s = 97
302
  target_frames = min(F, max_frames_4s)
 
343
  target_resampled_audio_path,
344
  )
345
  with torch.no_grad():
346
+ # Fast mode: if trim_to_4s, limit to 4 seconds
347
  target_frames = F
348
+ if trim_to_4s:
349
  # 4秒固定为97帧(4n+1格式:4秒*24fps=96帧,向上取整为97帧)
350
  max_frames_4s = 97
351
  target_frames = min(F, max_frames_4s)