rahul7star committed on
Commit b2baa38 · verified · 1 Parent(s): 7ccc31d

Create app_exp.py

Files changed (1)
  app_exp.py +268 -0
app_exp.py ADDED
@@ -0,0 +1,268 @@
+ import gradio as gr
+ import torch
+ import os
+ import sys
+ import subprocess
+ import tempfile
+ import numpy as np
+ import spaces
+ import importlib
+ import site
+ from PIL import Image
+ from huggingface_hub import snapshot_download, hf_hub_download
+
+ # ============================================================
+ # 1️⃣ FlashAttention 3 Setup (Auto-install from HF repo)
+ # ============================================================
+ try:
+     print("Attempting to download and install FlashAttention 3 wheel...")
+     fa3_wheel = hf_hub_download(
+         repo_id="rahul7star/flash-attn-3",
+         repo_type="model",
+         filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
+     )
+     # Install into the current interpreter's environment rather than
+     # whatever "pip" happens to be first on PATH.
+     subprocess.run([sys.executable, "-m", "pip", "install", fa3_wheel], check=True)
+     # Make the freshly installed package importable in this process.
+     site.addsitedir(site.getsitepackages()[0])
+     importlib.invalidate_caches()
+     print("✅ FlashAttention 3 installed successfully.")
+ except Exception as e:
+     print(f"⚠️ FlashAttention install failed: {e}")
+     print("Proceeding without FA3 acceleration...")
+
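+ # Optional post-install sanity check (a minimal sketch; it assumes FA3 wheels
+ # expose a `flash_attn_interface` module, but the import name can vary by build):
+ try:
+     import flash_attn_interface  # assumed FA3 module name, not verified here
+     print("FA3 kernels importable.")
+ except ImportError:
+     print("FA3 not importable; the DiT below falls back to other attention backends.")
+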
+ # ============================================================
+ # 2️⃣ Define model and repo paths
+ # ============================================================
+ REPO_PATH = "LongCat-Video"
+ CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
+
+ # ============================================================
+ # 3️⃣ Clone the model repo if needed
+ # ============================================================
+ if not os.path.exists(REPO_PATH):
+     print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
+     subprocess.run(
+         ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
+         check=True
+     )
+     print("✅ Repository cloned successfully.")
+
+ # Make repo importable
+ sys.path.insert(0, os.path.abspath(REPO_PATH))
+
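+ # For reproducible restarts, the clone could be pinned to a known revision
+ # (sketch only; "<pinned-ref>" is a placeholder, not a real commit):
+ # subprocess.run(["git", "-C", REPO_PATH, "checkout", "<pinned-ref>"], check=True)
+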
+ # ============================================================
+ # 4️⃣ Import model modules after repo setup
+ # ============================================================
+ from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
+ from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+ from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
+ from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
+ from longcat_video.context_parallel import context_parallel_util
+ from transformers import AutoTokenizer, UMT5EncoderModel
+ from diffusers.utils import export_to_video
+
+ # ============================================================
+ # 5️⃣ Download weights (snapshot)
+ # ============================================================
+ if not os.path.exists(CHECKPOINT_DIR):
+     print(f"Downloading model weights to '{CHECKPOINT_DIR}'...")
+     snapshot_download(
+         repo_id="meituan-longcat/LongCat-Video",
+         local_dir=CHECKPOINT_DIR,
+         local_dir_use_symlinks=False,  # deprecated and ignored by recent huggingface_hub versions
+         ignore_patterns=["*.md", "*.gitattributes", "assets/*"]
+     )
+     print("✅ Model weights ready.")
+
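+ # Lightweight sanity check (sketch): the from_pretrained calls below expect
+ # these subfolders inside CHECKPOINT_DIR; the list mirrors this file's usage.
+ for _sub in ("tokenizer", "text_encoder", "vae", "scheduler", "dit", "lora"):
+     if not os.path.isdir(os.path.join(CHECKPOINT_DIR, _sub)):
+         print(f"⚠️ Expected subfolder missing from checkpoint: {_sub}")
+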
+ # ============================================================
+ # 6️⃣ Initialize model pipeline
+ # ============================================================
+ pipe = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
+
+ print("--- Initializing Models (once at startup) ---")
+ try:
+     cp_split_hw = context_parallel_util.get_optimal_split(1)
+
+     tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)
+     text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)
+     vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
+     scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler", torch_dtype=torch_dtype)
+
+     # ✅ Enable FA3 acceleration
+     dit = LongCatVideoTransformer3DModel.from_pretrained(
+         CHECKPOINT_DIR,
+         enable_flashattn3=True,
+         enable_flashattn2=False,
+         enable_xformers=True,
+         subfolder="dit",
+         cp_split_hw=cp_split_hw,
+         torch_dtype=torch_dtype,
+     )
+
+     pipe = LongCatVideoPipeline(
+         tokenizer=tokenizer,
+         text_encoder=text_encoder,
+         vae=vae,
+         scheduler=scheduler,
+         dit=dit,
+     ).to(device)
+
+     # Load LoRAs
+     lora_dir = os.path.join(CHECKPOINT_DIR, "lora")
+     pipe.dit.load_lora(os.path.join(lora_dir, "cfg_step_lora.safetensors"), "cfg_step_lora")
+     pipe.dit.load_lora(os.path.join(lora_dir, "refinement_lora.safetensors"), "refinement_lora")
+
+     print("✅ Models loaded successfully.")
+ except Exception as e:
+     print(f"❌ FATAL: Model initialization failed.\n{e}")
+     pipe = None
+
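+ # Note: if the FA3 wheel failed to install above, a safer DiT configuration
+ # (sketch, using the same from_pretrained flags as above) would be
+ # enable_flashattn3=False with enable_flashattn2=True or enable_xformers=True.
+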
+ # ============================================================
+ # 7️⃣ GPU cleanup utility
+ # ============================================================
+ def torch_gc():
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.ipc_collect()
+
+ # ============================================================
+ # 8️⃣ Dynamic GPU duration logic
+ # ============================================================
+ def compute_duration(mode, prompt, neg_prompt, image, height, width, resolution, seed, use_distill, use_refine, progress):
+     """
+     Adaptive GPU time allocation based on resolution & refinement usage.
+     """
+     base = 120  # baseline (seconds)
+     if resolution == "720p":
+         base += 60
+     if use_refine:
+         base += 60
+     if use_distill:
+         base -= 30
+     return min(base, 240)  # cap at 4 min
+
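+ # Worked example: a 720p refine run with distill enabled requests
+ # 120 + 60 + 60 - 30 = 210 s; without distill it would hit the 240 s cap.
+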
+ # ============================================================
+ # 9️⃣ Generation function
+ # ============================================================
+ @spaces.GPU(duration=compute_duration)
+ def generate_video(
+     mode,
+     prompt,
+     neg_prompt,
+     image,
+     height, width, resolution,
+     seed,
+     use_distill,
+     use_refine,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     if pipe is None:
+         raise gr.Error("⚠️ Models failed to load. Restart the app.")
+
+     generator = torch.Generator(device=device).manual_seed(int(seed))
+     num_frames = 48  # shorter for faster test runs
+
+     # Refinement builds on the distilled base pass, so refine implies
+     # distill settings for Stage 1.
+     is_distill = use_distill or use_refine
+     pipe.dit.enable_loras(["cfg_step_lora"] if is_distill else [])
+
+     num_inference_steps = 12 if is_distill else 24
+     guidance_scale = 2.0 if is_distill else 4.0
+
+     # --- Stage 1 ---
+     progress(0.2, desc="Stage 1: Generating Base Video...")
+     if mode == "t2v":
+         output = pipe.generate_t2v(
+             prompt=prompt,
+             negative_prompt=neg_prompt,
+             height=height,
+             width=width,
+             num_frames=num_frames,
+             num_inference_steps=num_inference_steps,
+             use_distill=is_distill,
+             guidance_scale=guidance_scale,
+             generator=generator,
+         )[0]
+     else:
+         pil_img = Image.fromarray(image)
+         output = pipe.generate_i2v(
+             image=pil_img,
+             prompt=prompt,
+             negative_prompt=neg_prompt,
+             resolution=resolution,
+             num_frames=num_frames,
+             num_inference_steps=num_inference_steps,
+             use_distill=is_distill,
+             guidance_scale=guidance_scale,
+             generator=generator,
+         )[0]
+
+     pipe.dit.disable_all_loras()
+     torch_gc()
+
+     # --- Stage 2 ---
+     if use_refine:
+         progress(0.6, desc="Stage 2: Refining Video...")
+         pipe.dit.enable_loras(["refinement_lora"])
+         refined = pipe.generate_refine(
+             image=Image.fromarray(image) if mode == "i2v" else None,
+             prompt=prompt,
+             # Stage 1 frames are assumed to be float arrays in [0, 1],
+             # hence the scale to uint8 for PIL.
+             stage1_video=[Image.fromarray((f * 255).astype(np.uint8)) for f in output],
+             num_cond_frames=1 if mode == "i2v" else 0,
+             num_inference_steps=20,
+             generator=generator,
+         )[0]
+         output = refined
+         pipe.dit.disable_all_loras()
+         torch_gc()
+
+     # --- Export ---
+     progress(1.0, desc="Exporting video...")
+     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_vid:
+         export_to_video(output, tmp_vid.name, fps=24)
+     return tmp_vid.name
+
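+ # With num_frames = 48 and fps=24, each exported clip is 2 seconds long.
+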
+ # ============================================================
+ # 🔟 Gradio UI
+ # ============================================================
+ css = ".fillable{max-width:960px!important}"
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown("# 🎬 LongCat-Video + FA3 Accelerated 🚀")
+     gr.Markdown("13.6B parameter dense video model — with FlashAttention 3 for speed ⚡")
+
+     with gr.Tabs():
+         # Text-to-Video
+         with gr.TabItem("Text-to-Video"):
+             prompt_t2v = gr.Textbox(label="Prompt", lines=3, placeholder="A cinematic shot of a corgi running on the beach.")
+             neg_t2v = gr.Textbox(label="Negative Prompt", value="ugly, blurry, static")
+             h_t2v = gr.Slider(256, 1024, 480, step=64, label="Height")
+             w_t2v = gr.Slider(256, 1024, 832, step=64, label="Width")
+             seed_t2v = gr.Number(value=42, label="Seed")
+             distill_t2v = gr.Checkbox(label="Distill Mode", value=True)
+             refine_t2v = gr.Checkbox(label="Refine Mode", value=False)
+             btn_t2v = gr.Button("Generate Video", variant="primary")
+             out_t2v = gr.Video(label="Output Video")
+
+             # Event inputs must be Gradio components; constants like the mode
+             # string are wrapped in gr.State rather than passed as bare values.
+             btn_t2v.click(
+                 generate_video,
+                 inputs=[gr.State("t2v"), prompt_t2v, neg_t2v, gr.State(None), h_t2v, w_t2v, gr.State("480p"), seed_t2v, distill_t2v, refine_t2v],
+                 outputs=out_t2v,
+             )
+
+         # Image-to-Video
+         with gr.TabItem("Image-to-Video"):
+             img_i2v = gr.Image(type="numpy", label="Input Image")
+             prompt_i2v = gr.Textbox(label="Prompt", placeholder="The cat in the image blinks.")
+             neg_i2v = gr.Textbox(label="Negative Prompt", value="ugly, blurry")
+             resolution_i2v = gr.Dropdown(["480p", "720p"], value="480p", label="Resolution")
+             seed_i2v = gr.Number(value=42, label="Seed")
+             distill_i2v = gr.Checkbox(label="Distill Mode", value=True)
+             refine_i2v = gr.Checkbox(label="Refine Mode", value=False)
+             btn_i2v = gr.Button("Generate Video", variant="primary")
+             out_i2v = gr.Video(label="Output Video")
+
+             btn_i2v.click(
+                 generate_video,
+                 inputs=[gr.State("i2v"), prompt_i2v, neg_i2v, img_i2v, gr.State(None), gr.State(None), resolution_i2v, seed_i2v, distill_i2v, refine_i2v],
+                 outputs=out_i2v,
+             )
+
+ if __name__ == "__main__":
+     demo.launch()
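+ # Local usage: `python app_exp.py` serves the UI on Gradio's default
+ # http://127.0.0.1:7860; on Spaces the file runs as the app entrypoint.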