import spaces  # imported first so ZeroGPU can manage CUDA initialization

import os
import subprocess
import sys
import tempfile

import numpy as np
import torch
from PIL import Image
import gradio as gr
from torchvision.io import write_video

# ============================================================
# 1️⃣ Repo & checkpoint paths
# ============================================================
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
if not os.path.exists(REPO_PATH):
    subprocess.run(["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH], check=True)
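
# Fetch the model weights if they are missing. A minimal sketch, assuming the
# weights are published on the Hugging Face Hub as "meituan-longcat/LongCat-Video";
# adjust repo_id if yours live elsewhere.
if not os.path.exists(CHECKPOINT_DIR):
    from huggingface_hub import snapshot_download
    snapshot_download(repo_id="meituan-longcat/LongCat-Video", local_dir=CHECKPOINT_DIR)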

sys.path.insert(0, os.path.abspath(REPO_PATH))

from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
from transformers import AutoTokenizer, UMT5EncoderModel

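# Run in bfloat16 on GPU to halve memory versus float32; fall back to
# float32 on CPU, where bf16 kernels are not reliably available.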
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device=="cuda" else torch.float32

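# Release cached CUDA memory; called after each generation below so long
# sessions don't accumulate allocator fragmentation.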
def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

# ============================================================
# 2️⃣ Model loader with cache & 4-bit (NF4) quantization
# ============================================================
def load_models(checkpoint_dir=CHECKPOINT_DIR, cp_size=1, quantize=True, cache=True):
    cp_split_hw = context_parallel_util.get_optimal_split(cp_size)

    # The tokenizer and scheduler hold no weights, so torch_dtype is dropped for them.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, subfolder="tokenizer")
    text_encoder = UMT5EncoderModel.from_pretrained(checkpoint_dir, subfolder="text_encoder", torch_dtype=torch_dtype)
    vae = AutoencoderKLWan.from_pretrained(checkpoint_dir, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(checkpoint_dir, subfolder="scheduler")

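    # NF4 4-bit weight-only quantization via bitsandbytes roughly quarters the
    # transformer's weight memory; compute still runs in bfloat16.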
    if quantize:
        from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
        quant_cfg = DiffusersBitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype
        )
    else:
        quant_cfg = None

    dit = LongCatVideoTransformer3DModel.from_pretrained(
        checkpoint_dir,
        subfolder="dit",
        cp_split_hw=cp_split_hw,
        torch_dtype=torch_dtype,
        quantization_config=quant_cfg
    )

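    # cache_dit's DBCache reuses transformer block outputs across adjacent
    # denoising steps, skipping redundant compute for a speedup at a small
    # fidelity cost.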
    if cache:
        from cache_dit import enable_cache, BlockAdapter, ForwardPattern, DBCacheConfig
        enable_cache(
            BlockAdapter(transformer=dit, blocks=dit.blocks, forward_pattern=ForwardPattern.Pattern_3),
            cache_config=DBCacheConfig(Fn_compute_blocks=1)
        )

    pipe = LongCatVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        dit=dit
    )
    pipe.to(device)
    return pipe

pipe = load_models()

# ============================================================
# 3️⃣ LoRA refinement
# ============================================================
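# Load the refinement LoRA once at startup and enable block-sparse attention
# (BSA); the refine branch in generate_video re-enables the LoRA before use.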
pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors'), 'refinement_lora')
pipe.dit.enable_loras(['refinement_lora'])
pipe.dit.enable_bsa()

# ============================================================
# 4️⃣ Video generation function
# ============================================================
@spaces.GPU(duration=120)  # 50 denoising steps plus optional refine rarely fit in 60 s
def generate_video(
    mode,
    prompt,
    neg_prompt,
    image,
    height,
    width,
    num_frames,
    seed,
    use_refine,
):
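    # A fresh per-call generator keyed to the seed keeps results reproducible.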
    generator = torch.Generator(device=device).manual_seed(int(seed))

    if mode=="t2v":
        output = pipe.generate_t2v(
            prompt=prompt,
            negative_prompt=neg_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=50,
            guidance_scale=4.0,
            generator=generator
        )[0]
    else:
        if image is None:
            raise gr.Error("Please upload an input image for Image-to-Video.")
        pil_image = Image.fromarray(image)
        output = pipe.generate_i2v(
            image=pil_image,
            prompt=prompt,
            negative_prompt=neg_prompt,
            resolution=f"{height}x{width}",
            num_frames=num_frames,
            num_inference_steps=50,
            guidance_scale=4.0,
            generator=generator,
            use_kv_cache=True,
            offload_kv_cache=False
        )[0]

    if use_refine:
        pipe.dit.enable_loras(['refinement_lora'])
        pipe.dit.enable_bsa()
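        # Stage-1 frames are assumed to be float arrays in [0, 1]; convert them
        # to uint8 PIL images for the refinement pass.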
        stage1_video_pil = [(frame*255).astype(np.uint8) for frame in output]
        stage1_video_pil = [Image.fromarray(f) for f in stage1_video_pil]

        output = pipe.generate_refine(
            stage1_video=stage1_video_pil,
            prompt=prompt,
            num_cond_frames=1,
            num_inference_steps=50,
            generator=generator
        )[0]

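    # Frames are assumed to be floats in [0, 1]; scale to a (T, H, W, C)
    # uint8 tensor, the layout torchvision's write_video expects.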
    output_tensor = torch.from_numpy(np.array(output))
    output_tensor = (output_tensor*255).clamp(0,255).to(torch.uint8)
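
    # Free cached GPU memory before encoding the video.
    torch_gc()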

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        write_video(f.name, output_tensor, fps=15, video_codec="libx264", options={"crf": "18"})
        return f.name

# ============================================================
# 5️⃣ Gradio interface
# ============================================================
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Optimized LongCat-Video Demo (FA3 removed)")
    with gr.Tab("Text-to-Video"):
        prompt_t2v = gr.Textbox(label="Prompt", lines=3)
        neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality")
        height_t2v = gr.Slider(256, 1024, value=480, step=32, label="Height")  # step 32 so the 480 default is reachable
        width_t2v = gr.Slider(256, 1024, value=832, step=32, label="Width")
        frames_t2v = gr.Slider(8,180,value=48,step=1,label="Frames")
        seed_t2v = gr.Number(value=42,label="Seed",precision=0)
        refine_t2v = gr.Checkbox(label="Use Refine",value=True)
        out_t2v = gr.Video(label="Generated Video")
        btn_t2v = gr.Button("Generate")
        # Gradio click inputs must be components, so the fixed mode and the
        # absent image are wrapped in gr.State.
        mode_t2v = gr.State("t2v")
        image_t2v = gr.State(None)
        btn_t2v.click(
            generate_video,
            inputs=[mode_t2v, prompt_t2v, neg_prompt_t2v, image_t2v, height_t2v, width_t2v, frames_t2v, seed_t2v, refine_t2v],
            outputs=out_t2v
        )
    with gr.Tab("Image-to-Video"):
        image_i2v = gr.Image(type="numpy")
        prompt_i2v = gr.Textbox(label="Prompt", lines=3)
        neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality")
        frames_i2v = gr.Slider(8,180,value=48,step=1,label="Frames")
        seed_i2v = gr.Number(value=42,label="Seed",precision=0)
        refine_i2v = gr.Checkbox(label="Use Refine",value=True)
        out_i2v = gr.Video(label="Generated Video")
        btn_i2v = gr.Button("Generate")
        # Fixed mode and 480x832 resolution for this tab, wrapped in gr.State.
        mode_i2v = gr.State("i2v")
        height_i2v = gr.State(480)
        width_i2v = gr.State(832)
        btn_i2v.click(
            generate_video,
            inputs=[mode_i2v, prompt_i2v, neg_prompt_i2v, image_i2v, height_i2v, width_i2v, frames_i2v, seed_i2v, refine_i2v],
            outputs=out_i2v
        )

if __name__=="__main__":
    demo.launch()