import gradio as gr
import torch
import os
import sys
import subprocess
import tempfile
import numpy as np
import spaces
import importlib
import site
from PIL import Image
from huggingface_hub import snapshot_download, hf_hub_download

# ============================================================
# 1️⃣ FlashAttention 3 Setup (Auto-install from HF repo)
# ============================================================
try:
    print("Attempting to download and install FlashAttention 3 wheel...")
    fa3_wheel = hf_hub_download(
        repo_id="rahul7star/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    subprocess.run(["pip", "install", fa3_wheel], check=True)
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
    print("✅ FlashAttention 3 installed successfully.")
except Exception as e:
    print(f"⚠️ FlashAttention install failed: {e}")
    print("Proceeding without FA3 acceleration...")

# ============================================================
# 2️⃣ Define model and repo paths
# ============================================================
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")

# ============================================================
# 3️⃣ Clone the model repo if needed
# ============================================================
if not os.path.exists(REPO_PATH):
    print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
    subprocess.run(
        ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
        check=True,
    )
    print("✅ Repository cloned successfully.")

# Make repo importable
sys.path.insert(0, os.path.abspath(REPO_PATH))

# ============================================================
# 4️⃣ Import model modules after repo setup
# ============================================================
from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.utils import export_to_video

# ============================================================
# 5️⃣ Download weights (snapshot)
# ============================================================
if not os.path.exists(CHECKPOINT_DIR):
    print(f"Downloading model weights to '{CHECKPOINT_DIR}'...")
    snapshot_download(
        repo_id="meituan-longcat/LongCat-Video",
        local_dir=CHECKPOINT_DIR,
        local_dir_use_symlinks=False,
        ignore_patterns=["*.md", "*.gitattributes", "assets/*"],
    )
    print("✅ Model weights ready.")

# ============================================================
# 6️⃣ Initialize model pipeline
# ============================================================
pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

print("--- Initializing Models (once at startup) ---")
try:
    cp_split_hw = context_parallel_util.get_optimal_split(1)
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)
    text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)
    vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler", torch_dtype=torch_dtype)
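    # Assumption (not verified against the LongCat DiT internals): the transformer exposes
    # per-backend attention switches; FA3 is requested explicitly below, with xformers left
    # enabled as a likely fallback path when the FA3 wheel did not install above.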
    # ✅ Enable FA3 acceleration
    dit = LongCatVideoTransformer3DModel.from_pretrained(
        CHECKPOINT_DIR,
        enable_flashattn3=True,
        enable_flashattn2=False,
        enable_xformers=True,
        subfolder="dit",
        cp_split_hw=cp_split_hw,
        torch_dtype=torch_dtype,
    )

    pipe = LongCatVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        dit=dit,
    ).to(device)

    # Load LoRAs
    lora_dir = os.path.join(CHECKPOINT_DIR, "lora")
    pipe.dit.load_lora(os.path.join(lora_dir, "cfg_step_lora.safetensors"), "cfg_step_lora")
    pipe.dit.load_lora(os.path.join(lora_dir, "refinement_lora.safetensors"), "refinement_lora")
    print("✅ Models loaded successfully.")
except Exception as e:
    print(f"❌ FATAL: Model initialization failed.\n{e}")
    pipe = None

# ============================================================
# 7️⃣ GPU cleanup utility
# ============================================================
def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

# ============================================================
# 8️⃣ Dynamic GPU duration logic
# ============================================================
def compute_duration(mode, prompt, neg_prompt, image, height, width, resolution, seed, use_distill, use_refine, progress=None):
    """
    Adaptive GPU time allocation based on resolution & refinement usage.
    `progress` defaults to None so this callback also works when the GPU
    scheduler does not forward the Gradio progress argument.
    """
    base = 120  # baseline (seconds)
    if resolution == "720p":
        base += 60
    if use_refine:
        base += 60
    if use_distill:
        base -= 30
    return min(base, 240)  # cap at 4 min

# ============================================================
# 9️⃣ Generation function
# ============================================================
@spaces.GPU(duration=compute_duration)
def generate_video(
    mode, prompt, neg_prompt, image, height, width, resolution, seed,
    use_distill, use_refine, progress=gr.Progress(track_tqdm=True)
):
    if pipe is None:
        raise gr.Error("⚠️ Models failed to load. Restart the app.")
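    # Two-stage flow used below: Stage 1 generates the base clip (with the cfg_step_lora
    # and fewer steps when distill mode is on); Stage 2 optionally reruns the DiT with the
    # refinement_lora to sharpen the Stage 1 frames before export.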
    generator = torch.Generator(device=device).manual_seed(int(seed))
    num_frames = 48  # shorter for faster test runs

    is_distill = use_distill or use_refine
    pipe.dit.enable_loras(["cfg_step_lora"] if is_distill else [])
    num_inference_steps = 12 if is_distill else 24
    guidance_scale = 2.0 if is_distill else 4.0

    # --- Stage 1 ---
    progress(0.2, desc="Stage 1: Generating Base Video...")
    if mode == "t2v":
        output = pipe.generate_t2v(
            prompt=prompt,
            negative_prompt=neg_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]
    else:
        pil_img = Image.fromarray(image)
        output = pipe.generate_i2v(
            image=pil_img,
            prompt=prompt,
            negative_prompt=neg_prompt,
            resolution=resolution,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]

    pipe.dit.disable_all_loras()
    torch_gc()

    # --- Stage 2 ---
    if use_refine:
        progress(0.6, desc="Stage 2: Refining Video...")
        pipe.dit.enable_loras(["refinement_lora"])
        refined = pipe.generate_refine(
            image=Image.fromarray(image) if mode == "i2v" else None,
            prompt=prompt,
            stage1_video=[Image.fromarray((f * 255).astype(np.uint8)) for f in output],
            num_cond_frames=1 if mode == "i2v" else 0,
            num_inference_steps=20,
            generator=generator,
        )[0]
        output = refined
        pipe.dit.disable_all_loras()
        torch_gc()

    # --- Export ---
    progress(1.0, desc="Exporting video...")
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_vid:
        export_to_video(output, tmp_vid.name, fps=24)
    return tmp_vid.name

# ============================================================
# 🔟 Gradio UI
# ============================================================
css = ".fillable{max-width:960px!important}"
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🎬 LongCat-Video + FA3 Accelerated 🚀")
    gr.Markdown("13.6B parameter dense video model, with FlashAttention 3 for speed ⚡")

    with gr.Tabs():
        # Text-to-Video
        with gr.TabItem("Text-to-Video"):
            prompt_t2v = gr.Textbox(label="Prompt", lines=3, placeholder="A cinematic shot of a corgi running on the beach.")
            neg_t2v = gr.Textbox(label="Negative Prompt", value="ugly, blurry, static")
            h_t2v = gr.Slider(256, 1024, 480, step=64, label="Height")
            w_t2v = gr.Slider(256, 1024, 832, step=64, label="Width")
            seed_t2v = gr.Number(value=42, label="Seed")
            distill_t2v = gr.Checkbox(label="Distill Mode", value=True)
            refine_t2v = gr.Checkbox(label="Refine Mode", value=False)
            btn_t2v = gr.Button("Generate Video", variant="primary")
            out_t2v = gr.Video(label="Output Video")

            btn_t2v.click(
                generate_video,
                # Fixed values (mode, image, resolution) are wrapped in gr.State so Gradio
                # accepts them as inputs alongside the interactive components.
                inputs=[gr.State("t2v"), prompt_t2v, neg_t2v, gr.State(None), h_t2v, w_t2v, gr.State("480p"), seed_t2v, distill_t2v, refine_t2v],
                outputs=out_t2v,
            )

        # Image-to-Video
        with gr.TabItem("Image-to-Video"):
            img_i2v = gr.Image(type="numpy", label="Input Image")
            prompt_i2v = gr.Textbox(label="Prompt", placeholder="The cat in the image blinks.")
            neg_i2v = gr.Textbox(label="Negative Prompt", value="ugly, blurry")
            resolution_i2v = gr.Dropdown(["480p", "720p"], value="480p", label="Resolution")
            seed_i2v = gr.Number(value=42, label="Seed")
            distill_i2v = gr.Checkbox(label="Distill Mode", value=True)
            refine_i2v = gr.Checkbox(label="Refine Mode", value=False)
            btn_i2v = gr.Button("Generate Video", variant="primary")
            out_i2v = gr.Video(label="Output Video")

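            # Height and width are unused in i2v mode (generate_i2v sizes frames from the
            # resolution dropdown), so they are passed as None states here.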
            btn_i2v.click(
                generate_video,
                inputs=[gr.State("i2v"), prompt_i2v, neg_i2v, img_i2v, gr.State(None), gr.State(None), resolution_i2v, seed_i2v, distill_i2v, refine_i2v],
                outputs=out_i2v,
            )

if __name__ == "__main__":
    demo.launch()