SlouchyBuffalo committed on
Commit 8e9c347 · verified · 1 Parent(s): 7567340

Create app.py

Files changed (1)
app.py +255 -0
app.py ADDED
@@ -0,0 +1,255 @@
+ import spaces
+ import gradio as gr
+ from sd3_pipeline import StableDiffusion3Pipeline
+ import torch
+ import random
+ import numpy as np
+ import os
+ import gc
+ from diffusers import AutoencoderKLWan
+ from wan_pipeline import WanPipeline
+ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
+ from PIL import Image
+ from diffusers.utils import export_to_video
+ from huggingface_hub import login
+
+ # Authenticate with HF
+ login(token=os.getenv('HF_TOKEN'))
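+ # HF_TOKEN is expected to be configured as a Space secret; if it is unset,
+ # os.getenv returns None and login() has no token to authenticate with.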
+
+ def set_seed(seed):
+     random.seed(seed)
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
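+         # torch.cuda.manual_seed seeds only the current GPU; torch.cuda.manual_seed_all
+         # would also be needed for deterministic multi-GPU runs.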
+
+ # Updated model paths - now includes gated models
+ model_paths = {
+     "sd2.1": "stabilityai/stable-diffusion-2-1",
+     "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
+     "sd3": "stabilityai/stable-diffusion-3-medium-diffusers",
+     "sd3.5": "stabilityai/stable-diffusion-3.5-large",
+     # "wan-t2v": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"  # Keep commented if you don't have access to this one
+ }
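+ # Note: sd3 and sd3.5 are gated repositories; the account behind HF_TOKEN must have
+ # accepted their licenses on the Hub, otherwise from_pretrained() raises an authorization error.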
+
+ current_model = None
+ OUTPUT_DIR = "generated_videos"
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
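+ # load_model keeps exactly one pipeline resident: the previously loaded model is deleted
+ # and CUDA memory released before the requested checkpoint is loaded onto the device.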
+ def load_model(model_name):
+     global current_model
+     if current_model is not None:
+         del current_model
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         gc.collect()
+
+     # Determine device
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     if "wan-t2v" in model_name:
+         vae = AutoencoderKLWan.from_pretrained(
+             model_paths[model_name],
+             subfolder="vae",
+             torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
+         )
+         scheduler = UniPCMultistepScheduler(
+             prediction_type='flow_prediction',
+             use_flow_sigmas=True,
+             num_train_timesteps=1000,
+             flow_shift=8.0
+         )
+         current_model = WanPipeline.from_pretrained(
+             model_paths[model_name],
+             vae=vae,
+             torch_dtype=torch.float16 if device == "cuda" else torch.float32
+         ).to(device)
+         current_model.scheduler = scheduler
+     else:
+         # Handle different model types
+         if model_name in ["sd2.1"]:
+             from diffusers import StableDiffusionPipeline
+             current_model = StableDiffusionPipeline.from_pretrained(
+                 model_paths[model_name],
+                 torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
+             ).to(device)
+         elif model_name in ["sdxl"]:
+             from diffusers import StableDiffusionXLPipeline
+             current_model = StableDiffusionXLPipeline.from_pretrained(
+                 model_paths[model_name],
+                 torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
+             ).to(device)
+         else:
+             # For SD3 models (when access is granted)
+             current_model = StableDiffusion3Pipeline.from_pretrained(
+                 model_paths[model_name],
+                 torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
+             ).to(device)
+
+     return current_model
+
+ @spaces.GPU(duration=120)
+ def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps=50,
+                      use_cfg_zero_star=True, use_zero_init=True, zero_steps=0,
+                      seed=None, compare_mode=False):
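+     """Generate an image or a video with the selected model.
+
+     Returns a 4-tuple (cfg_zero_star_image, baseline_cfg_image, video_path, seed);
+     slots that do not apply to the selected mode are None.
+     """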
+
+     model = load_model(model_name)
+     if seed is None:
+         seed = random.randint(0, 2**32 - 1)
+     set_seed(seed)
+
+     is_video_model = "wan-t2v" in model_name
+     print('prompt: ', prompt)
+
+     if is_video_model:
+         negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+         video1_frames = model(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             height=480,
+             width=832,
+             num_frames=81,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             use_cfg_zero_star=True,
+             use_zero_init=True,
+             zero_steps=zero_steps
+         ).frames[0]
+         video1_path = os.path.join(OUTPUT_DIR, f"{seed}_CFG-Zero-Star.mp4")
+         export_to_video(video1_frames, video1_path, fps=16)
+
+         return None, None, video1_path, seed
+
+     # Handle different model types for image generation
+     if model_name in ["sd2.1", "sdxl"]:
+         # Standard diffusers pipeline interface
+         if compare_mode:
+             set_seed(seed)
+             image1 = model(
+                 prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+             ).images[0]
+
+             set_seed(seed)
+             image2 = model(
+                 prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+             ).images[0]
+
+             return image1, image2, None, seed
+         else:
+             image = model(
+                 prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+             ).images[0]
+
+             return image, None, None, seed
+     else:
+         # SD3 models with custom parameters
+         if compare_mode:
+             set_seed(seed)
+             image1 = model(
+                 prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+                 use_cfg_zero_star=True,
+                 use_zero_init=use_zero_init,
+                 zero_steps=zero_steps
+             ).images[0]
+
+             set_seed(seed)
+             image2 = model(
+                 prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+                 use_cfg_zero_star=False,
+                 use_zero_init=use_zero_init,
+                 zero_steps=zero_steps
+             ).images[0]
+
+             return image1, image2, None, seed
+         else:
+             image = model(
+                 prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+                 use_cfg_zero_star=use_cfg_zero_star,
+                 use_zero_init=use_zero_init,
+                 zero_steps=zero_steps
+             ).images[0]
+
+             if use_cfg_zero_star:
+                 return image, None, None, seed
+             else:
+                 return None, image, None, seed
+
+ # Gradio UI with left-right layout
+ with gr.Blocks() as demo:
+     gr.HTML("""
+     <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
+         CFG-Zero*: Improved Classifier-Free Guidance for Flow Matching Models
+     </div>
+     <div style="text-align: center;">
+         <a href="https://github.com/WeichenFan/CFG-Zero-star">Code</a> |
+         <a href="https://arxiv.org/abs/2503.18886">Paper</a>
+     </div>
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             prompt = gr.Textbox(value="A spooky haunted mansion on a hill silhouetted by a full moon.", label="Enter your prompt")
+             model_choice = gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model")
+             guidance_scale = gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale")
+             inference_steps = gr.Slider(10, 100, value=50, step=5, label="Inference Steps")
+             use_opt_scale = gr.Checkbox(value=True, label="Use Optimized-Scale")
+             use_zero_init = gr.Checkbox(value=True, label="Use Zero Init")
+             zero_steps = gr.Slider(0, 20, value=1, step=1, label="Zero out steps")
+             seed = gr.Number(value=42, label="Seed (Leave blank for random)")
+             compare_mode = gr.Checkbox(value=True, label="Compare Mode")
+             generate_btn = gr.Button("Generate")
+
+         with gr.Column(scale=2):
+             out1 = gr.Image(type="pil", label="CFG-Zero* Image")
+             out2 = gr.Image(type="pil", label="CFG Image")
+             video = gr.Video(label="Video")
+             used_seed = gr.Textbox(label="Used Seed")
+
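+     # Reset the generation controls to per-model defaults whenever a different model is selected.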
+     def update_params(model_name):
+         print('model_name: ', model_name)
+         if model_name == "wan-t2v":
+             return (
+                 gr.update(value=5),
+                 gr.update(value=50),
+                 gr.update(value=True),
+                 gr.update(value=True),
+                 gr.update(value=1)
+             )
+         else:
+             return (
+                 gr.update(value=4.0),
+                 gr.update(value=50),
+                 gr.update(value=True),
+                 gr.update(value=True),
+                 gr.update(value=1)
+             )
+
+     model_choice.change(
+         fn=update_params,
+         inputs=[model_choice],
+         outputs=[guidance_scale, inference_steps, use_opt_scale, use_zero_init, zero_steps]
+     )
+
+     generate_btn.click(
+         fn=generate_content,
+         inputs=[
+             prompt, model_choice, guidance_scale, inference_steps,
+             use_opt_scale, use_zero_init, zero_steps, seed, compare_mode
+         ],
+         outputs=[out1, out2, video, used_seed]
+     )
+
+ demo.launch(ssr_mode=False)