SlouchyBuffalo committed
Commit 7567340 · verified · 1 Parent(s): 68abd31

Delete app.py

Files changed (1)
  1. app.py +0 -186
app.py DELETED
@@ -1,186 +0,0 @@
- import spaces
- import gradio as gr
- from sd3_pipeline import StableDiffusion3Pipeline
- import torch
- import random
- import numpy as np
- import os
- import gc
- from diffusers import AutoencoderKLWan
- from wan_pipeline import WanPipeline
- from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
- from PIL import Image
- from diffusers.utils import export_to_video
- from huggingface_hub import login
-
- # Authenticate with HF
- login(token=os.getenv('HF_TOKEN'))
-
- def set_seed(seed):
-     random.seed(seed)
-     os.environ['PYTHONHASHSEED'] = str(seed)
-     np.random.seed(seed)
-     torch.manual_seed(seed)
-     torch.cuda.manual_seed(seed)
-
- model_paths = {
-     "sd3": "stabilityai/stable-diffusion-3-medium-diffusers",
-     "sd3.5": "stabilityai/stable-diffusion-3.5-large",
-     # "wan-t2v": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
- }
-
- current_model = None
- OUTPUT_DIR = "generated_videos"
- os.makedirs(OUTPUT_DIR, exist_ok=True)
-
- def load_model(model_name):
-     global current_model
-     if current_model is not None:
-         del current_model
-         torch.cuda.empty_cache()
-         gc.collect()
-
-     if "wan-t2v" in model_name:
-         vae = AutoencoderKLWan.from_pretrained(model_paths[model_name], subfolder="vae", torch_dtype=torch.bfloat16)
-         scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=8.0)
-         current_model = WanPipeline.from_pretrained(model_paths[model_name], vae=vae, torch_dtype=torch.float16).to("cuda")
-         current_model.scheduler = scheduler
-     else:
-         current_model = StableDiffusion3Pipeline.from_pretrained(model_paths[model_name], torch_dtype=torch.bfloat16).to("cuda")
-
-     return current_model.to("cuda")
-
- @spaces.GPU(duration=120)
- def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps=50,
-                      use_cfg_zero_star=True, use_zero_init=True, zero_steps=0,
-                      seed=None, compare_mode=False):
-
-     model = load_model(model_name)
-     if seed is None:
-         seed = random.randint(0, 2**32 - 1)
-     set_seed(seed)
-
-     is_video_model = "wan-t2v" in model_name
-     print('prompt: ', prompt)
-     if is_video_model:
-         negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
-         video1_frames = model(
-             prompt=prompt,
-             negative_prompt=negative_prompt,
-             height=480,
-             width=832,
-             num_frames=81,
-             num_inference_steps=num_inference_steps,
-             guidance_scale=guidance_scale,
-             use_cfg_zero_star=True,
-             use_zero_init=True,
-             zero_steps=zero_steps
-         ).frames[0]
-         video1_path = os.path.join(OUTPUT_DIR, f"{seed}_CFG-Zero-Star.mp4")
-         export_to_video(video1_frames, video1_path, fps=16)
-
-         return None, None, video1_path, seed
-
-     if compare_mode:
-         set_seed(seed)
-         image1 = model(
-             prompt,
-             guidance_scale=guidance_scale,
-             num_inference_steps=num_inference_steps,
-             use_cfg_zero_star=True,
-             use_zero_init=use_zero_init,
-             zero_steps=zero_steps
-         ).images[0]
-
-         set_seed(seed)
-         image2 = model(
-             prompt,
-             guidance_scale=guidance_scale,
-             num_inference_steps=num_inference_steps,
-             use_cfg_zero_star=False,
-             use_zero_init=use_zero_init,
-             zero_steps=zero_steps
-         ).images[0]
-
-         return image1, image2, seed
-     else:
-         image = model(
-             prompt,
-             guidance_scale=guidance_scale,
-             num_inference_steps=num_inference_steps,
-             use_cfg_zero_star=use_cfg_zero_star,
-             use_zero_init=use_zero_init,
-             zero_steps=zero_steps
-         ).images[0]
-
-         if use_cfg_zero_star:
-             return image, None, seed
-         else:
-             return None, image, seed
-
- # Gradio UI with left-right layout
- with gr.Blocks() as demo:
-     gr.HTML("""
-     <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
-         CFG-Zero*: Improved Classifier-Free Guidance for Flow Matching Models
-     </div>
-     <div style="text-align: center;">
-         <a href="https://github.com/WeichenFan/CFG-Zero-star">Code</a> |
-         <a href="https://arxiv.org/abs/2503.18886">Paper</a>
-     </div>
-     """)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             prompt = gr.Textbox(value="A spooky haunted mansion on a hill silhouetted by a full moon.", label="Enter your prompt")
-             model_choice = gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model")
-             guidance_scale = gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale")
-             inference_steps = gr.Slider(10, 100, value=50, step=5, label="Inference Steps")
-             use_opt_scale = gr.Checkbox(value=True, label="Use Optimized-Scale")
-             use_zero_init = gr.Checkbox(value=True, label="Use Zero Init")
-             zero_steps = gr.Slider(0, 20, value=1, step=1, label="Zero out steps")
-             seed = gr.Number(value=42, label="Seed (Leave blank for random)")
-             compare_mode = gr.Checkbox(value=True, label="Compare Mode")
-             generate_btn = gr.Button("Generate")
-
-         with gr.Column(scale=2):
-             out1 = gr.Image(type="pil", label="CFG-Zero* Image")
-             out2 = gr.Image(type="pil", label="CFG Image")
-             # video = gr.Video(label="Video")
-             used_seed = gr.Textbox(label="Used Seed")
-
-     def update_params(model_name):
-         print('model_name: ', model_name)
-         if model_name == "wan-t2v":
-             return (
-                 gr.update(value=5),
-                 gr.update(value=50),
-                 gr.update(value=True),
-                 gr.update(value=True),
-                 gr.update(value=1)
-             )
-         else:
-             return (
-                 gr.update(value=4.0),
-                 gr.update(value=50),
-                 gr.update(value=True),
-                 gr.update(value=True),
-                 gr.update(value=1)
-             )
-
-     model_choice.change(
-         fn=update_params,
-         inputs=[model_choice],
-         outputs=[guidance_scale, inference_steps, use_opt_scale, use_zero_init, zero_steps]
-     )
-
-     generate_btn.click(
-         fn=generate_content,
-         inputs=[
-             prompt, model_choice, guidance_scale, inference_steps,
-             use_opt_scale, use_zero_init, zero_steps, seed, compare_mode
-         ],
-         outputs=[out1, out2, used_seed]
-     )
-
- demo.launch(ssr_mode=False)
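
For reference, the toggles in the deleted UI correspond to the CFG-Zero* guidance rule from the linked paper (arXiv:2503.18886): "Use Optimized-Scale" rescales the unconditional prediction by a projected coefficient, and "Zero out steps" zeroes the predicted velocity for the first few sampling steps. Below is a minimal sketch of that rule, assuming a flow-matching model that produces conditional and unconditional velocity predictions; the function name and tensor handling are illustrative and not taken from the Space's sd3_pipeline or wan_pipeline code.

import torch

def cfg_zero_star_velocity(v_cond, v_uncond, guidance_scale,
                           step_index, use_zero_init=True, zero_steps=1):
    # Zero-init: emit an all-zero velocity for the first `zero_steps`
    # sampling steps (the "Zero out steps" slider in the UI above).
    if use_zero_init and step_index < zero_steps:
        return torch.zeros_like(v_cond)
    # Optimized scale s* = <v_cond, v_uncond> / ||v_uncond||^2, computed
    # per batch element over all non-batch dimensions.
    b = v_cond.shape[0]
    vc = v_cond.reshape(b, -1)
    vu = v_uncond.reshape(b, -1)
    s = (vc * vu).sum(dim=1) / (vu.pow(2).sum(dim=1) + 1e-8)
    s = s.view(b, *([1] * (v_cond.dim() - 1)))
    # Classifier-free guidance with the rescaled unconditional branch;
    # fixing s at 1 recovers standard CFG (use_cfg_zero_star=False above).
    return s * v_uncond + guidance_scale * (v_cond - s * v_uncond)

Compare mode in the app reruns the same seed with use_cfg_zero_star flipped, so out1 and out2 differ only in whether this optimized scaling is applied.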