multimodalart HF Staff commited on
Commit
f6764d3
·
verified ·
1 Parent(s): 7eeedf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -53
app.py CHANGED
@@ -28,59 +28,6 @@ if not os.path.exists(REPO_PATH):
28
  print(f"Error cloning repository: {e.stderr.decode()}")
29
  sys.exit(1)
30
 
31
- attention_file_path = os.path.join(REPO_PATH, "longcat_video", "modules", "attention.py")
32
- try:
33
- with open(attention_file_path, "r") as f:
34
- content = f.read()
35
-
36
- # Original code block that we need to replace
37
- original_code = """ x, *_ = flash_attn_func(
38
- q,
39
- k,
40
- v,
41
- softmax_scale=self.scale,
42
- )
43
- x = rearrange(x, "B S H D -> B H S D")"""
44
-
45
- # Corrected code block to handle FA3's 3D output shape
46
- corrected_code = """ x, *_ = flash_attn_func(
47
- q,
48
- k,
49
- v,
50
- softmax_scale=self.scale,
51
- )
52
- # The output of FA3's flash_attn_func can be 3D (total_tokens, H, D).
53
- # We need to robustly reshape it back to the 4D format (B, S, H, D) that the
54
- # subsequent rearrange operation expects.
55
- if x.ndim == 3:
56
- # B is the original batch size from the input q tensor
57
- B = q.shape[0]
58
- # S_total is the flattened batch and sequence length
59
- S_total, H, D = x.shape
60
- # Calculate the sequence length per batch item
61
- S = S_total // B
62
- x = x.view(B, S, H, D)
63
-
64
- x = rearrange(x, "B S H D -> B H S D")"""
65
-
66
- if original_code in content:
67
- print("="*50)
68
- print("Applying file patch to attention.py for FlashAttention-3 compatibility.")
69
- content = content.replace(original_code, corrected_code)
70
- with open(attention_file_path, "w") as f:
71
- f.write(content)
72
- print("Patch applied successfully.")
73
- print("="*50)
74
- else:
75
- print("Attention.py already seems to be patched or has changed. Skipping patch.")
76
-
77
- except FileNotFoundError:
78
- print(f"Error: Could not find {attention_file_path} to patch.")
79
- sys.exit(1)
80
-
81
- # Add the cloned repository to the Python path to allow imports
82
- sys.path.insert(0, os.path.abspath(REPO_PATH))
83
-
84
  # Now that the repo is in the path, we can import its modules
85
  from huggingface_hub import snapshot_download
86
  from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
 
28
  print(f"Error cloning repository: {e.stderr.decode()}")
29
  sys.exit(1)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # Now that the repo is in the path, we can import its modules
32
  from huggingface_hub import snapshot_download
33
  from longcat_video.pipeline_longcat_video import LongCatVideoPipeline