rahul7star commited on
Commit
ece32ef
·
verified ·
1 Parent(s): 73264d5

Update app_exp.py

Browse files
Files changed (1) hide show
  1. app_exp.py +34 -17
app_exp.py CHANGED
@@ -9,40 +9,57 @@ import torch
9
  import gradio as gr
10
  from torchvision.io import write_video
11
 
12
- # LongCat-Video imports
13
  REPO_PATH = "LongCat-Video"
14
  CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
15
 
 
16
  if not os.path.exists(REPO_PATH):
17
- subprocess.run(
18
- ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
19
- check=True
20
- )
21
-
 
 
 
 
 
 
 
 
22
  sys.path.insert(0, os.path.abspath(REPO_PATH))
 
 
23
  from huggingface_hub import snapshot_download
24
  from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
25
  from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
26
  from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
27
  from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
28
  from longcat_video.context_parallel import context_parallel_util
29
- import cache_dit
30
-
31
  from transformers import AutoTokenizer, UMT5EncoderModel
32
  from diffusers.utils import export_to_video
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
  torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
36
 
37
- # --- Download weights if missing ---
38
- if not os.path.exists(CHECKPOINT_DIR):
39
- snapshot_download(
40
- repo_id="meituan-longcat/LongCat-Video",
41
- local_dir=CHECKPOINT_DIR,
42
- local_dir_use_symlinks=False,
43
- ignore_patterns=["*.md", "*.gitattributes", "assets/*"]
44
- )
45
-
46
  # --- Initialize models ---
47
  cp_split_hw = context_parallel_util.get_optimal_split(1)
48
  tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)
 
9
  import gradio as gr
10
  from torchvision.io import write_video
11
 
12
+ # Define paths
13
  REPO_PATH = "LongCat-Video"
14
  CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
15
 
16
+ # Clone the repository if it doesn't exist
17
  if not os.path.exists(REPO_PATH):
18
+ print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
19
+ try:
20
+ subprocess.run(
21
+ ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
22
+ check=True,
23
+ capture_output=True
24
+ )
25
+ print("Repository cloned successfully.")
26
+ except subprocess.CalledProcessError as e:
27
+ print(f"Error cloning repository: {e.stderr.decode()}")
28
+ sys.exit(1)
29
+
30
+ # Add the cloned repository to the Python path to allow imports
31
  sys.path.insert(0, os.path.abspath(REPO_PATH))
32
+
33
+ # Now that the repo is in the path, we can import its modules
34
  from huggingface_hub import snapshot_download
35
  from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
36
  from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
37
  from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
38
  from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
39
  from longcat_video.context_parallel import context_parallel_util
 
 
40
  from transformers import AutoTokenizer, UMT5EncoderModel
41
  from diffusers.utils import export_to_video
42
 
43
+ # Download model weights from Hugging Face Hub if they don't exist
44
+ if not os.path.exists(CHECKPOINT_DIR):
45
+ print(f"Downloading model weights to '{CHECKPOINT_DIR}'...")
46
+ try:
47
+ snapshot_download(
48
+ repo_id="meituan-longcat/LongCat-Video",
49
+ local_dir=CHECKPOINT_DIR,
50
+ local_dir_use_symlinks=False, # Use False for better Windows compatibility
51
+ ignore_patterns=["*.md", "*.gitattributes", "assets/*"] # ignore non-essential files
52
+ )
53
+ print("Model weights downloaded successfully.")
54
+ except Exception as e:
55
+ print(f"Error downloading model weights: {e}")
56
+ sys.exit(1)
57
+
58
+ # Global placeholder for the pipeline and device configuration
59
+ pipe = None
60
  device = "cuda" if torch.cuda.is_available() else "cpu"
61
  torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
62
 
 
 
 
 
 
 
 
 
 
63
  # --- Initialize models ---
64
  cp_split_hw = context_parallel_util.get_optimal_split(1)
65
  tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)