# LongCat-Video / app_exp.py
# Gradio demo for Meituan's LongCat-Video with Cache-DiT caching and 4-bit quantization.
import gradio as gr
import torch
import os
import sys
import subprocess
import tempfile
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
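# The @spaces.GPU decorator below needs the `spaces` package, which is available on
# Hugging Face Spaces (ZeroGPU). The fallback stub is a minimal sketch (assumption:
# the app may also be launched locally) that turns the decorator into a no-op off-Spaces.
try:
    import spaces
except ImportError:
    class _SpacesStub:
        @staticmethod
        def GPU(*args, **kwargs):
            # Only the @spaces.GPU(duration=...) call form used in this file is covered
            def _decorate(fn):
                return fn
            return _decorate
    spaces = _SpacesStub()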
# ============================================================
# 0️⃣ Install required packages
# ============================================================
subprocess.run(["pip3", "install", "-U", "cache-dit"], check=True)
import cache_dit
# ============================================================
# 1️⃣ Repository & Weights
# ============================================================
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
if not os.path.exists(REPO_PATH):
subprocess.run(
["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
check=True
)
sys.path.insert(0, os.path.abspath(REPO_PATH))
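# The cloned repo root is now on sys.path, so the pipeline and model classes below resolve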
from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.utils import export_to_video
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
if not os.path.exists(CHECKPOINT_DIR):
    snapshot_download(
        repo_id="meituan-longcat/LongCat-Video",
        local_dir=CHECKPOINT_DIR,
        ignore_patterns=["*.md", "*.gitattributes", "assets/*"]
    )
# ============================================================
# 2️⃣ Device & Models (with cache & quantization)
# ============================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device=="cuda" else torch.float32
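# bf16 halves weight/activation memory relative to fp32 on GPU; fp32 keeps a CPU fallback usable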
pipe = None
try:
    # Single-process run: no context parallelism, so request the trivial 1-way split
    cp_split_hw = context_parallel_util.get_optimal_split(1)
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer")
# Text encoder with 4-bit quantization
text_encoder = UMT5EncoderModel.from_pretrained(
CHECKPOINT_DIR,
subfolder="text_encoder",
torch_dtype=torch_dtype,
quantization_config=TransformersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch_dtype
)
)
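    # NF4 stores weights in 4 bits versus 16 for bf16, roughly quartering text-encoder
    # weight memory; matmuls still run in bf16 via bnb_4bit_compute_dtype.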
    vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler")
    # DiT backbone in bf16; Cache-DiT is applied to its blocks below.
    # FlashAttention-3 targets Hopper-class GPUs, so it is off by default and xformers
    # attention is used instead (assumption: flip this flag on H100-class hardware).
    enable_fa3 = False
    dit = LongCatVideoTransformer3DModel.from_pretrained(
        CHECKPOINT_DIR,
        enable_flashattn3=enable_fa3,
        enable_xformers=True,
        subfolder="dit",
        cp_split_hw=cp_split_hw,
        torch_dtype=torch_dtype
    )
# Enable Cache-DiT
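    # DBCache ("dual block cache"): after a 5-step warmup, every step recomputes only the
    # first Fn and last Bn transformer blocks and reuses cached residuals for the middle
    # blocks while the residual change stays below residual_diff_threshold.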
cache_dit.enable_cache(
cache_dit.BlockAdapter(
transformer=dit,
blocks=dit.blocks,
forward_pattern=cache_dit.ForwardPattern.Pattern_3,
check_forward_pattern=False,
has_separate_cfg=False
),
cache_config=cache_dit.DBCacheConfig(
Fn_compute_blocks=1,
Bn_compute_blocks=1,
max_warmup_steps=5,
max_cached_steps=50,
max_continuous_cached_steps=50,
residual_diff_threshold=0.01,
num_inference_steps=50
)
)
pipe = LongCatVideoPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
scheduler=scheduler,
dit=dit,
)
pipe.to(device)
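    # Untested hint: if VAE decode OOMs on long clips, diffusers' AutoencoderKLWan
    # supports tiled decoding, and this repo's local copy may expose the same method:
    # pipe.vae.enable_tiling()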
print("✅ Models loaded with Cache-DiT and quantization")
except Exception as e:
print(f"❌ Failed to load models: {e}")
pipe = None
# ============================================================
# 3️⃣ Generation Helper
# ============================================================
def torch_gc():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
# Time-budget estimator for ZeroGPU. Passed as duration=check_duration below, so it
# must accept the same arguments as generate_video and return a duration in seconds.
def check_duration(mode, prompt, neg_prompt, image, height, width, resolution,
                   seed, use_distill, use_refine, duration_sec,
                   progress=gr.Progress(track_tqdm=True)):
    if use_distill and resolution == "480p":
        return 180
    elif resolution == "720p":
        return 360
    else:
        return 900
@spaces.GPU(duration=check_duration)
def generate_video(mode, prompt, neg_prompt, image, height, width, resolution,
seed, use_distill, use_refine, duration_sec, progress=gr.Progress(track_tqdm=True)):
if pipe is None:
raise gr.Error("Models not loaded")
fps = 15 if use_distill else 30
num_frames = int(duration_sec * fps)
generator = torch.Generator(device=device).manual_seed(int(seed))
    # Refinement builds on a distilled stage-1 pass, so refine implies the distill settings
    is_distill = use_distill or use_refine
progress(0.2, desc="Stage 1: Base Video Generation")
    # Distilled path: enable the cfg_step_lora, trade step count and guidance strength
    # for speed, and clear the negative prompt in this mode
    pipe.dit.enable_loras(['cfg_step_lora'] if is_distill else [])
    num_inference_steps = 12 if is_distill else 24
    guidance_scale = 2.0 if is_distill else 4.0
    curr_neg_prompt = "" if is_distill else neg_prompt
if mode=="t2v":
output = pipe.generate_t2v(
prompt=prompt,
negative_prompt=curr_neg_prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
use_distill=is_distill,
guidance_scale=guidance_scale,
generator=generator
)[0]
    else:
        if image is None:
            raise gr.Error("Please upload an input image for image-to-video")
        pil_img = Image.fromarray(image)
output = pipe.generate_i2v(
image=pil_img,
prompt=prompt,
negative_prompt=curr_neg_prompt,
resolution=resolution,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
use_distill=is_distill,
guidance_scale=guidance_scale,
generator=generator
)[0]
pipe.dit.disable_all_loras()
torch_gc()
    if use_refine:
        progress(0.5, desc="Stage 2: Refinement")
        # Refinement LoRA plus BSA (block-sparse attention) for the second pass
        pipe.dit.enable_loras(['refinement_lora'])
        pipe.dit.enable_bsa()
        # Stage-1 frames come back as float arrays in [0, 1]; convert to uint8 PIL images
        stage1_video_pil = [Image.fromarray((frame * 255).astype(np.uint8)) for frame in output]
        refine_image = Image.fromarray(image) if mode == 'i2v' else None
output = pipe.generate_refine(
image=refine_image,
prompt=prompt,
stage1_video=stage1_video_pil,
num_cond_frames=1 if mode=='i2v' else 0,
num_inference_steps=50,
generator=generator
)[0]
pipe.dit.disable_all_loras()
pipe.dit.disable_bsa()
torch_gc()
progress(1.0, desc="Exporting video")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
export_to_video(output, f.name, fps=fps)
return f.name
# ============================================================
# 4️⃣ Gradio UI
# ============================================================
css=".fillable{max-width:960px !important}"
with gr.Blocks(css=css) as demo:
gr.Markdown("# 🎬 LongCat-Video with Cache-DiT & Quantization")
gr.Markdown("13.6B parameter dense video-generation model by Meituan — [[Model](https://huggingface.co/meituan-longcat/LongCat-Video)]")
with gr.Tabs():
# Text-to-Video
with gr.TabItem("Text-to-Video"):
mode_t2v = gr.State("t2v")
with gr.Row():
with gr.Column(scale=2):
prompt_t2v = gr.Textbox(label="Prompt", lines=4)
neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="blurry, low quality")
height_t2v = gr.Slider(256,1024,step=64,value=480,label="Height")
width_t2v = gr.Slider(256,1024,step=64,value=832,label="Width")
seed_t2v = gr.Number(value=42,label="Seed")
distill_t2v = gr.Checkbox(value=True,label="Use Distill Mode")
refine_t2v = gr.Checkbox(value=False,label="Use Refine Mode")
duration_t2v = gr.Slider(1,20,step=1,value=2,label="Video Duration (seconds)")
t2v_button = gr.Button("Generate Video")
with gr.Column(scale=3):
video_output_t2v = gr.Video(label="Generated Video")
# Image-to-Video
with gr.TabItem("Image-to-Video"):
mode_i2v = gr.State("i2v")
with gr.Row():
with gr.Column(scale=2):
image_i2v = gr.Image(type="numpy", label="Input Image")
prompt_i2v = gr.Textbox(label="Prompt", lines=4)
neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="blurry, low quality")
resolution_i2v = gr.Dropdown(["480p","720p"], value="480p", label="Resolution")
seed_i2v = gr.Number(value=42,label="Seed")
distill_i2v = gr.Checkbox(value=True,label="Use Distill Mode")
refine_i2v = gr.Checkbox(value=False,label="Use Refine Mode")
duration_i2v = gr.Slider(1,20,step=1,value=2,label="Video Duration (seconds)")
i2v_button = gr.Button("Generate Video")
with gr.Column(scale=3):
video_output_i2v = gr.Video(label="Generated Video")
# Bind events
t2v_button.click(
generate_video,
inputs=[mode_t2v, prompt_t2v, neg_prompt_t2v, gr.State(None),
height_t2v, width_t2v, gr.State("480p"),
seed_t2v, distill_t2v, refine_t2v, duration_t2v],
outputs=video_output_t2v
)
i2v_button.click(
generate_video,
inputs=[mode_i2v, prompt_i2v, neg_prompt_i2v, image_i2v,
gr.State(None), gr.State(None), resolution_i2v,
seed_i2v, distill_i2v, refine_i2v, duration_i2v],
outputs=video_output_i2v
)
# Launch
if __name__=="__main__":
demo.launch()