import spaces
import gradio as gr
import torch
import os
import sys
import subprocess
import tempfile
import numpy as np
from PIL import Image

# Define paths
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")

# Clone the repo if missing
if not os.path.exists(REPO_PATH):
    print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
    subprocess.run(
        ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
        check=True
    )

sys.path.insert(0, os.path.abspath(REPO_PATH))

# Imports from the LongCat repo
from huggingface_hub import snapshot_download
from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.utils import export_to_video
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

# Download model weights if missing
if not os.path.exists(CHECKPOINT_DIR):
    snapshot_download(
        repo_id="meituan-longcat/LongCat-Video",
        local_dir=CHECKPOINT_DIR,
        local_dir_use_symlinks=False,
        ignore_patterns=["*.md", "*.gitattributes", "assets/*"]
    )

pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

print("--- Initializing Models ---")
try:
    cp_split_hw = context_parallel_util.get_optimal_split(1)
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)
    text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)
    vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler", torch_dtype=torch_dtype)

    # Optional 4-bit (NF4) quantization config; not applied unless passed as
    # quantization_config to the DiT below.
    bnb_4bit_config = DiffusersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    dit = LongCatVideoTransformer3DModel.from_pretrained(
        CHECKPOINT_DIR,
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=True,
        subfolder="dit",
        cp_split_hw=cp_split_hw,
        torch_dtype=torch_dtype,
        # quantization_config=bnb_4bit_config,  # uncomment to load the DiT in 4-bit
    )

    pipe = LongCatVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        dit=dit,
    ).to(device)

    # LoRA adapters: CFG/step distillation and refinement
    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, "lora/cfg_step_lora.safetensors"), "cfg_step_lora")
    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, "lora/refinement_lora.safetensors"), "refinement_lora")

    print("--- Models loaded successfully ---")
except Exception as e:
    print("❌ Model load error:", e)
    pipe = None


# -------------------- GPU Cleanup --------------------
def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


# -------------------- Video Generation --------------------
def check_duration(*args, duration_t2v=2, **kwargs):
    """Dynamic ZeroGPU duration: roughly one GPU-second per output frame, plus headroom."""
    # Gradio passes the UI inputs positionally, so read duration_t2v (the 11th input) from args.
    if len(args) >= 11:
        duration_t2v = args[10]
    fps = 30
    return duration_t2v * fps + 30


@spaces.GPU(duration=check_duration)
def generate_video(
    mode,
    prompt,
    neg_prompt,
    image,
    height,
    width,
    resolution,
    seed,
    use_distill,
    use_refine,
    duration_t2v=2,
    progress=gr.Progress(track_tqdm=True)
):
    if pipe is None:
        raise gr.Error("Models failed to load.")
    generator = torch.Generator(device=device).manual_seed(int(seed))
    num_frames = int(duration_t2v * 30)  # frame count derived from duration at 30 fps
    print("Prompt:", prompt)

    # Distilled sampling (few steps, no CFG) is used for both distill and refine modes
    is_distill = use_distill or use_refine
    if is_distill:
        pipe.dit.enable_loras(['cfg_step_lora'])
        num_inference_steps = 16
        guidance_scale = 1.0
        neg = ""
    else:
        num_inference_steps = 50
        guidance_scale = 4.0
        neg = neg_prompt

    if mode == "t2v":
        output = pipe.generate_t2v(
            prompt=prompt,
            negative_prompt=neg,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]
    else:
        pil_image = Image.fromarray(image)
        output = pipe.generate_i2v(
            image=pil_image,
            prompt=prompt,
            negative_prompt=neg,
            resolution=resolution,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]

    if is_distill:
        pipe.dit.disable_all_loras()
    torch_gc()

    if use_refine:
        progress(0.5, desc="Refining")
        pipe.dit.enable_loras(['refinement_lora'])
        pipe.dit.enable_bsa()

        # Convert stage-1 frames (floats in [0, 1]) to PIL images for refinement
        frames = [(frame * 255).astype(np.uint8) for frame in output]
        frames = [Image.fromarray(f) for f in frames]
        ref_img = Image.fromarray(image) if mode == "i2v" else None

        output = pipe.generate_refine(
            image=ref_img,
            prompt=prompt,
            stage1_video=frames,
            num_cond_frames=1 if mode == "i2v" else 0,
            num_inference_steps=50,
            generator=generator,
        )[0]

        pipe.dit.disable_all_loras()
        pipe.dit.disable_bsa()
        torch_gc()

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        export_to_video(output, tmp.name, fps=30)
        print("Video generated:", tmp.name)
        return tmp.name


# -------------------- Gradio UI --------------------
css = ".fillable{max-width:960px !important}"

with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🎬 LongCat-Video")
    gr.Markdown("13.6B parameter dense video-generation model — [Hugging Face](https://huggingface.co/meituan-longcat/LongCat-Video)")

    with gr.Tabs():
        # --- Text-to-Video ---
        with gr.TabItem("Text-to-Video"):
            mode_t2v = gr.State("t2v")
            prompt_t2v = gr.Textbox(label="Prompt", lines=4)
            neg_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles")
            height_t2v = gr.Slider(256, 1024, value=480, step=64, label="Height")
            width_t2v = gr.Slider(256, 1024, value=832, step=64, label="Width")
            seed_t2v = gr.Number(label="Seed", value=42)
            distill_t2v = gr.Checkbox(label="Use Distill Mode", value=True)
            refine_t2v = gr.Checkbox(label="Use Refine Mode", value=False)
            duration_t2v = gr.Slider(1, 20, step=1, value=2, label="Duration (seconds)")
            t2v_button = gr.Button("Generate Video")
            video_out_t2v = gr.Video(label="Generated Video")

            t2v_button.click(
                fn=generate_video,
                inputs=[mode_t2v, prompt_t2v, neg_t2v, gr.State(None), height_t2v, width_t2v,
                        gr.State(None), seed_t2v, distill_t2v, refine_t2v, duration_t2v],
                outputs=video_out_t2v
            )

        # --- Image-to-Video ---
        with gr.TabItem("Image-to-Video"):
            mode_i2v = gr.State("i2v")
            image_i2v = gr.Image(type="numpy", label="Input Image")
            prompt_i2v = gr.Textbox(label="Prompt", lines=4)
            neg_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles, watermark")
            resolution_i2v = gr.Dropdown(["480p", "720p"], value="480p", label="Resolution")
            seed_i2v = gr.Number(label="Seed", value=42)
            distill_i2v = gr.Checkbox(label="Use Distill Mode", value=True)
            refine_i2v = gr.Checkbox(label="Use Refine Mode", value=False)
            duration_i2v = gr.Slider(1, 20, step=1, value=2, label="Duration (seconds)")
            i2v_button = gr.Button("Generate Video")
Video") video_out_i2v = gr.Video(label="Generated Video") i2v_button.click( fn=generate_video, inputs=[mode_i2v, prompt_i2v, neg_i2v, image_i2v, gr.State(None), gr.State(None), resolution_i2v, seed_i2v, distill_i2v, refine_i2v, duration_i2v], outputs=video_out_i2v ) if __name__ == "__main__": demo.launch()