Spaces:
Running
on
Zero
Running
on
Zero
Fix VAE scale tensor device mismatch: recreate scale list on GPU and move T5 encoder
Browse files
app.py
CHANGED
|
@@ -620,12 +620,26 @@ def run_graio_demo(args):
|
|
| 620 |
wan_a2v.vae.mean = wan_a2v.vae.mean.to(cuda_device)
|
| 621 |
if hasattr(wan_a2v.vae, 'std'):
|
| 622 |
wan_a2v.vae.std = wan_a2v.vae.std.to(cuda_device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 623 |
|
| 624 |
# 移动 CLIP 模型到 GPU
|
| 625 |
if hasattr(wan_a2v, 'clip') and wan_a2v.clip is not None:
|
| 626 |
if hasattr(wan_a2v.clip, 'model'):
|
| 627 |
wan_a2v.clip.model = wan_a2v.clip.model.to(cuda_device)
|
| 628 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 629 |
# 更新设备信息
|
| 630 |
wan_a2v.device = cuda_device
|
| 631 |
|
|
|
|
| 620 |
wan_a2v.vae.mean = wan_a2v.vae.mean.to(cuda_device)
|
| 621 |
if hasattr(wan_a2v.vae, 'std'):
|
| 622 |
wan_a2v.vae.std = wan_a2v.vae.std.to(cuda_device)
|
| 623 |
+
# 重新创建 scale 列表,确保在 GPU 上
|
| 624 |
+
if hasattr(wan_a2v.vae, 'mean') and hasattr(wan_a2v.vae, 'std'):
|
| 625 |
+
wan_a2v.vae.scale = [wan_a2v.vae.mean, 1.0 / wan_a2v.vae.std]
|
| 626 |
+
# 更新 VAE 的设备属性
|
| 627 |
+
if hasattr(wan_a2v.vae, 'device'):
|
| 628 |
+
wan_a2v.vae.device = cuda_device
|
| 629 |
|
| 630 |
# 移动 CLIP 模型到 GPU
|
| 631 |
if hasattr(wan_a2v, 'clip') and wan_a2v.clip is not None:
|
| 632 |
if hasattr(wan_a2v.clip, 'model'):
|
| 633 |
wan_a2v.clip.model = wan_a2v.clip.model.to(cuda_device)
|
| 634 |
|
| 635 |
+
# 移动 T5 encoder 到 GPU(如果不在 CPU 上)
|
| 636 |
+
if hasattr(wan_a2v, 'text_encoder') and wan_a2v.text_encoder is not None:
|
| 637 |
+
if hasattr(wan_a2v.text_encoder, 'model'):
|
| 638 |
+
wan_a2v.text_encoder.model = wan_a2v.text_encoder.model.to(cuda_device)
|
| 639 |
+
# 更新 T5 encoder 的设备属性
|
| 640 |
+
if hasattr(wan_a2v.text_encoder, 'device'):
|
| 641 |
+
wan_a2v.text_encoder.device = cuda_device
|
| 642 |
+
|
| 643 |
# 更新设备信息
|
| 644 |
wan_a2v.device = cuda_device
|
| 645 |
|