import spaces
import os
import sys
import subprocess  # used to clone the LongCat-Video repo at startup
import tempfile
import datetime
import numpy as np
from PIL import Image
import gradio as gr
import torch
import torch.distributed as dist
from torchvision.io import write_video
# ============================================================
# 1️⃣ Repo & checkpoint paths
# ============================================================
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
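# Clone the upstream LongCat-Video repo on first run so its Python package can be
# imported from source; the checkpoints are expected under weights/LongCat-Video.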
if not os.path.exists(REPO_PATH):
    subprocess.run(["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH], check=True)
sys.path.insert(0, os.path.abspath(REPO_PATH))
from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel.context_parallel_util import init_context_parallel
from longcat_video.context_parallel import context_parallel_util
import cache_dit
from transformers import AutoTokenizer, UMT5EncoderModel

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
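# bfloat16 keeps GPU memory usage low; fall back to float32 when no CUDA device is available.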
def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
# ============================================================
# 2️⃣ Model loader with cache & 4-bit NF4 quantization
# ============================================================
def load_models(checkpoint_dir=CHECKPOINT_DIR, cp_size=1, quantize=True, cache=True):
    cp_split_hw = context_parallel_util.get_optimal_split(cp_size)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, subfolder="tokenizer", torch_dtype=torch_dtype)
    text_encoder = UMT5EncoderModel.from_pretrained(checkpoint_dir, subfolder="text_encoder", torch_dtype=torch_dtype)
    vae = AutoencoderKLWan.from_pretrained(checkpoint_dir, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(checkpoint_dir, subfolder="scheduler", torch_dtype=torch_dtype)
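    # Optionally load the DiT weights in 4-bit NF4 via bitsandbytes to reduce VRAM use;
    # the text encoder and VAE stay in the default dtype.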
    if quantize:
        from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
        quant_cfg = DiffusersBitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
        )
    else:
        quant_cfg = None
    dit = LongCatVideoTransformer3DModel.from_pretrained(
        checkpoint_dir,
        subfolder="dit",
        cp_split_hw=cp_split_hw,
        torch_dtype=torch_dtype,
        quantization_config=quant_cfg,
    )
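    # cache_dit's DBCache skips redundant DiT block computation across denoising steps;
    # Fn_compute_blocks=1 is assumed to mean only the first block is always recomputed
    # and used to decide when cached block outputs can be reused.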
    if cache:
        from cache_dit import enable_cache, BlockAdapter, ForwardPattern, DBCacheConfig
        enable_cache(
            BlockAdapter(transformer=dit, blocks=dit.blocks, forward_pattern=ForwardPattern.Pattern_3),
            cache_config=DBCacheConfig(Fn_compute_blocks=1),
        )
    pipe = LongCatVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        dit=dit,
    )
    pipe.to(device)
    return pipe
pipe = load_models()
# ============================================================
# 3️⃣ LoRA refinement
# ============================================================
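# Load the refinement LoRA and enable it together with BSA (assumed to be LongCat's
# block-sparse attention used by the refinement stage); generate_video re-enables both
# before its optional refine pass.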
pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors'), 'refinement_lora')
pipe.dit.enable_loras(['refinement_lora'])
pipe.dit.enable_bsa()
# ============================================================
# 4️⃣ Video generation function
# ============================================================
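# Single entry point for both tabs: "t2v" ignores the image input, "i2v" animates the
# uploaded frame, and use_refine optionally runs a second refinement pass over the
# stage-1 frames.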
@spaces.GPU  # request a ZeroGPU slot for the duration of each generation call
def generate_video(
    mode,
    prompt,
    neg_prompt,
    image,
    height,
    width,
    num_frames,
    seed,
    use_refine,
):
    generator = torch.Generator(device=device).manual_seed(int(seed))
    if mode == "t2v":
        output = pipe.generate_t2v(
            prompt=prompt,
            negative_prompt=neg_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=50,
            guidance_scale=4.0,
            generator=generator,
        )[0]
    else:
        pil_image = Image.fromarray(image)
        output = pipe.generate_i2v(
            image=pil_image,
            prompt=prompt,
            negative_prompt=neg_prompt,
            resolution=f"{height}x{width}",
            num_frames=num_frames,
            num_inference_steps=50,
            guidance_scale=4.0,
            generator=generator,
            use_kv_cache=True,
            offload_kv_cache=False,
        )[0]
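    # Optional second pass: the stage-1 frames (assumed float arrays in [0, 1]) are
    # converted to PIL images and refined with the refinement LoRA + BSA enabled.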
    if use_refine:
        pipe.dit.enable_loras(['refinement_lora'])
        pipe.dit.enable_bsa()
        stage1_video_pil = [(frame * 255).clip(0, 255).astype(np.uint8) for frame in output]
        stage1_video_pil = [Image.fromarray(f) for f in stage1_video_pil]
        output = pipe.generate_refine(
            stage1_video=stage1_video_pil,
            prompt=prompt,
            num_cond_frames=1,
            num_inference_steps=50,
            generator=generator,
        )[0]
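    # Convert the float frames in [0, 1] to a uint8 (T, H, W, C) tensor, which is what
    # torchvision's write_video expects for H.264 encoding.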
    output_tensor = torch.from_numpy(np.array(output))
    output_tensor = (output_tensor * 255).clamp(0, 255).to(torch.uint8)
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        write_video(f.name, output_tensor, fps=15, video_codec="libx264", options={"crf": "18"})
    torch_gc()  # free cached GPU memory between requests
    return f.name
# ============================================================
# 5️⃣ Gradio interface
# ============================================================
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Optimized LongCat-Video Demo (FA3 removed)")
    with gr.Tab("Text-to-Video"):
        prompt_t2v = gr.Textbox(label="Prompt", lines=3)
        neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality")
        height_t2v = gr.Slider(256, 1024, value=480, step=64, label="Height")
        width_t2v = gr.Slider(256, 1024, value=832, step=64, label="Width")
        frames_t2v = gr.Slider(8, 180, value=48, step=1, label="Frames")
        seed_t2v = gr.Number(value=42, label="Seed", precision=0)
        refine_t2v = gr.Checkbox(label="Use Refine", value=True)
        out_t2v = gr.Video(label="Generated Video")
        btn_t2v = gr.Button("Generate")
        # Gradio inputs must be components, so the constant mode/image arguments are bound in a wrapper.
        btn_t2v.click(
            lambda p, n, h, w, f, s, r: generate_video("t2v", p, n, None, h, w, f, s, r),
            inputs=[prompt_t2v, neg_prompt_t2v, height_t2v, width_t2v, frames_t2v, seed_t2v, refine_t2v],
            outputs=out_t2v,
        )
    with gr.Tab("Image-to-Video"):
        image_i2v = gr.Image(type="numpy")
        prompt_i2v = gr.Textbox(label="Prompt", lines=3)
        neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality")
        frames_i2v = gr.Slider(8, 180, value=48, step=1, label="Frames")
        seed_i2v = gr.Number(value=42, label="Seed", precision=0)
        refine_i2v = gr.Checkbox(label="Use Refine", value=True)
        out_i2v = gr.Video(label="Generated Video")
        btn_i2v = gr.Button("Generate")
        # The i2v tab has no size sliders, so a fixed 480x832 resolution is bound in the wrapper.
        btn_i2v.click(
            lambda p, n, img, f, s, r: generate_video("i2v", p, n, img, 480, 832, f, s, r),
            inputs=[prompt_i2v, neg_prompt_i2v, image_i2v, frames_i2v, seed_i2v, refine_i2v],
            outputs=out_i2v,
        )
if __name__ == "__main__":
    demo.launch()