rahul7star committed on
Commit 35a5c71 · verified · 1 Parent(s): 11d4598

Update app_exp.py

Files changed (1): app_exp.py +105 -214
app_exp.py CHANGED
@@ -1,55 +1,22 @@
  import spaces
- import gradio as gr
- import torch
  import os
  import sys
- import subprocess
  import tempfile
  import numpy as np
- import site
- import importlib
  from PIL import Image
- from huggingface_hub import snapshot_download, hf_hub_download
-
- # ============================================================
- # 0️⃣ Install required packages
- # ============================================================
- subprocess.run(["pip3", "install", "-U", "cache-dit"], check=True)
-
-
-
- import cache_dit
-
- enable_fa3 = False  # default if FA3 cannot be loaded
-
- try:
-     print("Installing FlashAttention 3...")
-     flash_attention_wheel = hf_hub_download(
-         repo_id="rahul7star/flash-attn-3",
-         repo_type="model",
-         filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
-     )
-     subprocess.run(["pip", "install", flash_attention_wheel], check=True)
-     site.addsitedir(site.getsitepackages()[0])
-     importlib.invalidate_caches()
-     enable_fa3 = True
-     print("✅ FlashAttention 3 installed and enabled")
- except Exception as e:
-     print(f"⚠️ Could not install FlashAttention 3: {e}")
-     # enable_fa3 remains False
-

  # ============================================================
- # 1️⃣ Repository & Weights
  # ============================================================
  REPO_PATH = "LongCat-Video"
  CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
-
  if not os.path.exists(REPO_PATH):
-     subprocess.run(
-         ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
-         check=True
-     )

  sys.path.insert(0, os.path.abspath(REPO_PATH))

@@ -57,248 +24,172 @@ from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
  from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
  from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
  from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
  from longcat_video.context_parallel import context_parallel_util
  from transformers import AutoTokenizer, UMT5EncoderModel
- from diffusers.utils import export_to_video
- from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
- from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

- if not os.path.exists(CHECKPOINT_DIR):
-     snapshot_download(
-         repo_id="meituan-longcat/LongCat-Video",
-         local_dir=CHECKPOINT_DIR,
-         local_dir_use_symlinks=False,
-         ignore_patterns=["*.md", "*.gitattributes", "assets/*"]
-     )

  # ============================================================
- # 2️⃣ Device & Models (with cache & quantization)
  # ============================================================
- device = "cuda" if torch.cuda.is_available() else "cpu"
- torch_dtype = torch.bfloat16 if device=="cuda" else torch.float32
- pipe = None

- try:
-     cp_split_hw = context_parallel_util.get_optimal_split(1)
-     tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)

-     # Text encoder with 4-bit quantization
-     text_encoder = UMT5EncoderModel.from_pretrained(
-         CHECKPOINT_DIR,
-         subfolder="text_encoder",
-         torch_dtype=torch_dtype,
-         quantization_config=TransformersBitsAndBytesConfig(
              load_in_4bit=True,
              bnb_4bit_quant_type="nf4",
              bnb_4bit_compute_dtype=torch_dtype
          )
-     )
-
-     vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
-     scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler", torch_dtype=torch_dtype)

-     # DiT model with FP8/4-bit quantization + cache
      dit = LongCatVideoTransformer3DModel.from_pretrained(
-         CHECKPOINT_DIR,
-         enable_flashattn3=enable_fa3,
-         enable_xformers=True,
          subfolder="dit",
          cp_split_hw=cp_split_hw,
-         torch_dtype=torch_dtype
      )

-     # Enable Cache-DiT
-     cache_dit.enable_cache(
-         cache_dit.BlockAdapter(
-             transformer=dit,
-             blocks=dit.blocks,
-             forward_pattern=cache_dit.ForwardPattern.Pattern_3,
-             check_forward_pattern=False,
-             has_separate_cfg=False
-         ),
-         cache_config=cache_dit.DBCacheConfig(
-             Fn_compute_blocks=1,
-             Bn_compute_blocks=1,
-             max_warmup_steps=5,
-             max_cached_steps=50,
-             max_continuous_cached_steps=50,
-             residual_diff_threshold=0.01,
-             num_inference_steps=50
          )
-     )

      pipe = LongCatVideoPipeline(
          tokenizer=tokenizer,
          text_encoder=text_encoder,
          vae=vae,
          scheduler=scheduler,
-         dit=dit,
      )
      pipe.to(device)
-     print("✅ Models loaded with Cache-DiT and quantization")

- except Exception as e:
-     print(f"❌ Failed to load models: {e}")
-     pipe = None

  # ============================================================
- # 3️⃣ Generation Helper
  # ============================================================
- def torch_gc():
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-         torch.cuda.ipc_collect()

- def check_duration(
      mode,
-     prompt,
-     neg_prompt,
      image,
-     height, width, resolution,
      seed,
-     use_distill,
      use_refine,
-     progress
  ):
-     if use_distill and resolution=="480p":
-         return 180
-     elif resolution=="720p":
-         return 360
-     else:
-         return 900
-
- @spaces.GPU(duration=180)
- def generate_video(mode, prompt, neg_prompt, image, height, width, resolution,
-                    seed, use_distill, use_refine, duration_sec, progress=gr.Progress(track_tqdm=True)):
-
-     if pipe is None:
-         raise gr.Error("Models not loaded")
-
-     fps = 15 if use_distill else 30
-     num_frames = int(duration_sec * fps)
      generator = torch.Generator(device=device).manual_seed(int(seed))
-     is_distill = use_distill or use_refine
-
-     progress(0.2, desc="Stage 1: Base Video Generation")
-     pipe.dit.enable_loras(['cfg_step_lora'] if is_distill else [])
-     num_inference_steps = 12 if is_distill else 24
-     guidance_scale = 2.0 if is_distill else 4.0
-     curr_neg_prompt = "" if is_distill else neg_prompt

      if mode=="t2v":
          output = pipe.generate_t2v(
              prompt=prompt,
-             negative_prompt=curr_neg_prompt,
              height=height,
              width=width,
              num_frames=num_frames,
-             num_inference_steps=num_inference_steps,
-             use_distill=is_distill,
-             guidance_scale=guidance_scale,
              generator=generator
          )[0]
      else:
-         pil_img = Image.fromarray(image)
          output = pipe.generate_i2v(
-             image=pil_img,
              prompt=prompt,
-             negative_prompt=curr_neg_prompt,
-             resolution=resolution,
              num_frames=num_frames,
-             num_inference_steps=num_inference_steps,
-             use_distill=is_distill,
-             guidance_scale=guidance_scale,
-             generator=generator
          )[0]

-     pipe.dit.disable_all_loras()
-     torch_gc()
-
      if use_refine:
-         progress(0.5, desc="Stage 2: Refinement")
          pipe.dit.enable_loras(['refinement_lora'])
          pipe.dit.enable_bsa()
          stage1_video_pil = [(frame*255).astype(np.uint8) for frame in output]
-         stage1_video_pil = [Image.fromarray(img) for img in stage1_video_pil]
-         refine_image = Image.fromarray(image) if mode=='i2v' else None
          output = pipe.generate_refine(
-             image=refine_image,
-             prompt=prompt,
              stage1_video=stage1_video_pil,
-             num_cond_frames=1 if mode=='i2v' else 0,
              num_inference_steps=50,
              generator=generator
          )[0]
-         pipe.dit.disable_all_loras()
-         pipe.dit.disable_bsa()
-         torch_gc()

-     progress(1.0, desc="Exporting video")
      with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
-         export_to_video(output, f.name, fps=fps)
      return f.name

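A side note on the check_duration helper defined above: ZeroGPU Spaces also document a dynamic-duration form of the decorator, where duration is a callable that receives the same arguments as the wrapped function, so the 180/360/900-second budgets computed here could drive the GPU allocation directly instead of the fixed @spaces.GPU(duration=180). A minimal sketch of that wiring, assuming the installed spaces version supports callable durations (the argument list mirrors the function above; progress is simplified to a plain keyword):

import spaces

def check_duration(mode, prompt, neg_prompt, image, height, width, resolution,
                   seed, use_distill, use_refine, duration_sec, progress=None):
    # Per-call GPU budget in seconds, mirroring the thresholds above.
    if use_distill and resolution == "480p":
        return 180
    elif resolution == "720p":
        return 360
    return 900

@spaces.GPU(duration=check_duration)  # dynamic duration instead of a fixed 180 s
def generate_video(mode, prompt, neg_prompt, image, height, width, resolution,
                   seed, use_distill, use_refine, duration_sec, progress=None):
    ...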
  # ============================================================
- # 4️⃣ Gradio UI
  # ============================================================
- css=".fillable{max-width:960px !important}"
-
- with gr.Blocks(css=css) as demo:
-     gr.Markdown("# 🎬 LongCat-Video with Cache-DiT & Quantization")
-     gr.Markdown("13.6B parameter dense video-generation model by Meituan — [[Model](https://huggingface.co/meituan-longcat/LongCat-Video)]")
-
-     with gr.Tabs():
-         # Text-to-Video
-         with gr.TabItem("Text-to-Video"):
-             mode_t2v = gr.State("t2v")
-             with gr.Row():
-                 with gr.Column(scale=2):
-                     prompt_t2v = gr.Textbox(label="Prompt", lines=4)
-                     neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="blurry, low quality")
-                     height_t2v = gr.Slider(256,1024,step=64,value=480,label="Height")
-                     width_t2v = gr.Slider(256,1024,step=64,value=832,label="Width")
-                     seed_t2v = gr.Number(value=42,label="Seed")
-                     distill_t2v = gr.Checkbox(value=True,label="Use Distill Mode")
-                     refine_t2v = gr.Checkbox(value=False,label="Use Refine Mode")
-                     duration_t2v = gr.Slider(1,20,step=1,value=2,label="Video Duration (seconds)")
-                     t2v_button = gr.Button("Generate Video")
-                 with gr.Column(scale=3):
-                     video_output_t2v = gr.Video(label="Generated Video")
-
-         # Image-to-Video
-         with gr.TabItem("Image-to-Video"):
-             mode_i2v = gr.State("i2v")
-             with gr.Row():
-                 with gr.Column(scale=2):
-                     image_i2v = gr.Image(type="numpy", label="Input Image")
-                     prompt_i2v = gr.Textbox(label="Prompt", lines=4)
-                     neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="blurry, low quality")
-                     resolution_i2v = gr.Dropdown(["480p","720p"], value="480p", label="Resolution")
-                     seed_i2v = gr.Number(value=42,label="Seed")
-                     distill_i2v = gr.Checkbox(value=True,label="Use Distill Mode")
-                     refine_i2v = gr.Checkbox(value=False,label="Use Refine Mode")
-                     duration_i2v = gr.Slider(1,20,step=1,value=2,label="Video Duration (seconds)")
-                     i2v_button = gr.Button("Generate Video")
-                 with gr.Column(scale=3):
-                     video_output_i2v = gr.Video(label="Generated Video")
-
-     # Bind events
-     t2v_button.click(
-         generate_video,
-         inputs=[mode_t2v, prompt_t2v, neg_prompt_t2v, gr.State(None),
-                 height_t2v, width_t2v, gr.State("480p"),
-                 seed_t2v, distill_t2v, refine_t2v, duration_t2v],
-         outputs=video_output_t2v
-     )
-
-     i2v_button.click(
-         generate_video,
-         inputs=[mode_i2v, prompt_i2v, neg_prompt_i2v, image_i2v,
-                 gr.State(None), gr.State(None), resolution_i2v,
-                 seed_i2v, distill_i2v, refine_i2v, duration_i2v],
-         outputs=video_output_i2v
-     )

- # Launch
  if __name__=="__main__":
      demo.launch()

  import spaces
  import os
  import sys
  import tempfile
+ import datetime
  import numpy as np
  from PIL import Image
+ import gradio as gr
+ import torch
+ import torch.distributed as dist
+ from torchvision.io import write_video
+ import subprocess  # needed for the git clone below

  # ============================================================
+ # 1️⃣ Repo & checkpoint paths
  # ============================================================
  REPO_PATH = "LongCat-Video"
  CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
  if not os.path.exists(REPO_PATH):
+     subprocess.run(["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH], check=True)

  sys.path.insert(0, os.path.abspath(REPO_PATH))

  from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
  from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
  from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
+ from longcat_video.context_parallel.context_parallel_util import init_context_parallel
  from longcat_video.context_parallel import context_parallel_util
+ import cache_dit
  from transformers import AutoTokenizer, UMT5EncoderModel

+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.bfloat16 if device=="cuda" else torch.float32
+
+ def torch_gc():
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.ipc_collect()

  # ============================================================
+ # 2️⃣ Model loader with cache & 4-bit/FP8 quantization
  # ============================================================
+ def load_models(checkpoint_dir=CHECKPOINT_DIR, cp_size=1, quantize=True, cache=True):
+     cp_split_hw = context_parallel_util.get_optimal_split(cp_size)

+     tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, subfolder="tokenizer", torch_dtype=torch_dtype)
+     text_encoder = UMT5EncoderModel.from_pretrained(checkpoint_dir, subfolder="text_encoder", torch_dtype=torch_dtype)
+     vae = AutoencoderKLWan.from_pretrained(checkpoint_dir, subfolder="vae", torch_dtype=torch_dtype)
+     scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(checkpoint_dir, subfolder="scheduler", torch_dtype=torch_dtype)

+     if quantize:
+         from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+         quant_cfg = DiffusersBitsAndBytesConfig(
              load_in_4bit=True,
              bnb_4bit_quant_type="nf4",
              bnb_4bit_compute_dtype=torch_dtype
          )
+     else:
+         quant_cfg = None

      dit = LongCatVideoTransformer3DModel.from_pretrained(
+         checkpoint_dir,
          subfolder="dit",
          cp_split_hw=cp_split_hw,
+         torch_dtype=torch_dtype,
+         quantization_config=quant_cfg
      )

+     if cache:
+         from cache_dit import enable_cache, BlockAdapter, ForwardPattern, DBCacheConfig
+         enable_cache(
+             BlockAdapter(transformer=dit, blocks=dit.blocks, forward_pattern=ForwardPattern.Pattern_3),
+             cache_config=DBCacheConfig(Fn_compute_blocks=1)
          )

      pipe = LongCatVideoPipeline(
          tokenizer=tokenizer,
          text_encoder=text_encoder,
          vae=vae,
          scheduler=scheduler,
+         dit=dit
      )
      pipe.to(device)
+     return pipe

+ pipe = load_models()
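A quick usage note on the loader above: the quantize and cache keyword flags let the bitsandbytes NF4 step and the Cache-DiT step be toggled independently. For example (these calls only exercise the signature defined in this file; cp_size > 1 would additionally need torch.distributed to be initialized, which this app does not do):

pipe_bf16 = load_models(quantize=False)    # plain bfloat16 weights, no bitsandbytes
pipe_nocache = load_models(cache=False)    # quantized weights, Cache-DiT left disabled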
 
  # ============================================================
+ # 3️⃣ LoRA refinement
  # ============================================================
+ pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors'), 'refinement_lora')
+ pipe.dit.enable_loras(['refinement_lora'])
+ pipe.dit.enable_bsa()

+ # ============================================================
+ # 4️⃣ Video generation function
+ # ============================================================
+ @spaces.GPU(duration=60)
+ def generate_video(
      mode,
+     prompt,
+     neg_prompt,
      image,
+     height,
+     width,
+     num_frames,
      seed,
      use_refine,
  ):
      generator = torch.Generator(device=device).manual_seed(int(seed))

      if mode=="t2v":
          output = pipe.generate_t2v(
              prompt=prompt,
+             negative_prompt=neg_prompt,
              height=height,
              width=width,
              num_frames=num_frames,
+             num_inference_steps=50,
+             guidance_scale=4.0,
              generator=generator
          )[0]
      else:
+         pil_image = Image.fromarray(image)
          output = pipe.generate_i2v(
+             image=pil_image,
              prompt=prompt,
+             negative_prompt=neg_prompt,
+             resolution=f"{height}x{width}",
              num_frames=num_frames,
+             num_inference_steps=50,
+             guidance_scale=4.0,
+             generator=generator,
+             use_kv_cache=True,
+             offload_kv_cache=False
          )[0]

      if use_refine:
          pipe.dit.enable_loras(['refinement_lora'])
          pipe.dit.enable_bsa()
          stage1_video_pil = [(frame*255).astype(np.uint8) for frame in output]
+         stage1_video_pil = [Image.fromarray(f) for f in stage1_video_pil]
+
          output = pipe.generate_refine(
              stage1_video=stage1_video_pil,
+             prompt=prompt,
+             num_cond_frames=1,
              num_inference_steps=50,
              generator=generator
          )[0]

+     output_tensor = torch.from_numpy(np.array(output))
+     output_tensor = (output_tensor*255).clamp(0,255).to(torch.uint8)
+
      with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
+         write_video(f.name, output_tensor, fps=15, video_codec="libx264", options={"crf": "18"})
      return f.name
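For readers tracing the export path above: torchvision.io.write_video expects a uint8 tensor shaped (num_frames, height, width, channels), which is why the float frames are stacked, scaled by 255, and clamped before writing. A self-contained sketch with dummy frames (sizes arbitrary; PyAV/ffmpeg must be available for the libx264 codec):

import numpy as np
import torch
from torchvision.io import write_video

# Pretend pipeline output: 16 RGB frames in [0, 1], height 64, width 96.
frames = [np.random.rand(64, 96, 3).astype(np.float32) for _ in range(16)]

video = torch.from_numpy(np.stack(frames))            # (T, H, W, C), float32 in [0, 1]
video = (video * 255).clamp(0, 255).to(torch.uint8)   # uint8 layout write_video expects

write_video("demo.mp4", video, fps=15, video_codec="libx264", options={"crf": "18"})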
 
  # ============================================================
+ # 5️⃣ Gradio interface
  # ============================================================
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🎬 Optimized LongCat-Video Demo (FA3 removed)")
+     with gr.Tab("Text-to-Video"):
+         prompt_t2v = gr.Textbox(label="Prompt", lines=3)
+         neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality")
+         height_t2v = gr.Slider(256,1024,value=480,step=64,label="Height")
+         width_t2v = gr.Slider(256,1024,value=832,step=64,label="Width")
+         frames_t2v = gr.Slider(8,180,value=48,step=1,label="Frames")
+         seed_t2v = gr.Number(value=42,label="Seed",precision=0)
+         refine_t2v = gr.Checkbox(label="Use Refine",value=True)
+         out_t2v = gr.Video(label="Generated Video")
+         btn_t2v = gr.Button("Generate")
+         btn_t2v.click(
+             generate_video,
+             inputs=[gr.State("t2v"), prompt_t2v, neg_prompt_t2v, gr.State(None),
+                     height_t2v, width_t2v, frames_t2v, seed_t2v, refine_t2v],
+             outputs=out_t2v
+         )
+     with gr.Tab("Image-to-Video"):
+         image_i2v = gr.Image(type="numpy")
+         prompt_i2v = gr.Textbox(label="Prompt", lines=3)
+         neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality")
+         frames_i2v = gr.Slider(8,180,value=48,step=1,label="Frames")
+         seed_i2v = gr.Number(value=42,label="Seed",precision=0)
+         refine_i2v = gr.Checkbox(label="Use Refine",value=True)
+         out_i2v = gr.Video(label="Generated Video")
+         btn_i2v = gr.Button("Generate")
+         btn_i2v.click(
+             generate_video,
+             inputs=[gr.State("i2v"), prompt_i2v, neg_prompt_i2v, image_i2v,
+                     gr.State(480), gr.State(832), frames_i2v, seed_i2v, refine_i2v],
+             outputs=out_i2v
+         )

  if __name__=="__main__":
      demo.launch()
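One design note on the event wiring in both versions of the UI: Gradio's .click() expects components in inputs, so fixed per-call values (the mode string, an unused image slot, a hard-coded 480x832 size) ride along as gr.State rather than as raw literals. A minimal self-contained sketch of that pattern (the echo function and labels are purely illustrative):

import gradio as gr

def echo(mode, text):
    # The callback receives the gr.State value positionally, like any other input.
    return f"[{mode}] {text}"

with gr.Blocks() as demo:
    box = gr.Textbox(label="Text")
    out = gr.Textbox(label="Result")
    btn = gr.Button("Run")
    # gr.State carries the constant "t2v" into the callback alongside the textbox value.
    btn.click(echo, inputs=[gr.State("t2v"), box], outputs=out)

if __name__ == "__main__":
    demo.launch()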