C4G-HKUST commited on
Commit
c2db27b
·
1 Parent(s): f2436d3

Fix VAE scale tensor device mismatch: recreate scale list on GPU and move T5 encoder

Browse files
Files changed (1) hide show
  1. app.py +14 -0
app.py CHANGED
@@ -620,12 +620,26 @@ def run_graio_demo(args):
620
  wan_a2v.vae.mean = wan_a2v.vae.mean.to(cuda_device)
621
  if hasattr(wan_a2v.vae, 'std'):
622
  wan_a2v.vae.std = wan_a2v.vae.std.to(cuda_device)
 
 
 
 
 
 
623
 
624
  # 移动 CLIP 模型到 GPU
625
  if hasattr(wan_a2v, 'clip') and wan_a2v.clip is not None:
626
  if hasattr(wan_a2v.clip, 'model'):
627
  wan_a2v.clip.model = wan_a2v.clip.model.to(cuda_device)
628
 
 
 
 
 
 
 
 
 
629
  # 更新设备信息
630
  wan_a2v.device = cuda_device
631
 
 
620
  wan_a2v.vae.mean = wan_a2v.vae.mean.to(cuda_device)
621
  if hasattr(wan_a2v.vae, 'std'):
622
  wan_a2v.vae.std = wan_a2v.vae.std.to(cuda_device)
623
+ # 重新创建 scale 列表,确保在 GPU 上
624
+ if hasattr(wan_a2v.vae, 'mean') and hasattr(wan_a2v.vae, 'std'):
625
+ wan_a2v.vae.scale = [wan_a2v.vae.mean, 1.0 / wan_a2v.vae.std]
626
+ # 更新 VAE 的设备属性
627
+ if hasattr(wan_a2v.vae, 'device'):
628
+ wan_a2v.vae.device = cuda_device
629
 
630
  # 移动 CLIP 模型到 GPU
631
  if hasattr(wan_a2v, 'clip') and wan_a2v.clip is not None:
632
  if hasattr(wan_a2v.clip, 'model'):
633
  wan_a2v.clip.model = wan_a2v.clip.model.to(cuda_device)
634
 
635
+ # 移动 T5 encoder 到 GPU(如果不在 CPU 上)
636
+ if hasattr(wan_a2v, 'text_encoder') and wan_a2v.text_encoder is not None:
637
+ if hasattr(wan_a2v.text_encoder, 'model'):
638
+ wan_a2v.text_encoder.model = wan_a2v.text_encoder.model.to(cuda_device)
639
+ # 更新 T5 encoder 的设备属性
640
+ if hasattr(wan_a2v.text_encoder, 'device'):
641
+ wan_a2v.text_encoder.device = cuda_device
642
+
643
  # 更新设备信息
644
  wan_a2v.device = cuda_device
645