rahul7star committed on
Commit
2ca29b4
·
verified ·
1 Parent(s): 211f2c9

Update app_flash.py

Browse files
Files changed (1) hide show
  1. app_flash.py +47 -0
app_flash.py CHANGED
@@ -153,6 +153,53 @@ def train_flashpack_model(
153
  # 5️⃣ Load or train model
154
  # ============================================================
155
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  try:
157
  print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
158
  model = GemmaTrainer.from_flashpack(hf_repo)
 
153
  # 5️⃣ Load or train model
154
  # ============================================================
155
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
156
+ input_dim = 1536 # must match the input_dim used during training
157
+ try:
158
+ print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
159
+
160
+ # 1️⃣ Try local model first
161
+ local_model_path = "model.flashpack"
162
+ if os.path.exists(local_model_path):
163
+ print("✅ Loading local model")
164
+ else:
165
+ # 2️⃣ Try Hugging Face
166
+ files = list_repo_files(hf_repo)
167
+ if "model.flashpack" in files:
168
+ print("✅ Downloading model from HF")
169
+ from huggingface_hub import hf_hub_download
170
+ local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
171
+ else:
172
+ print("🚫 No pretrained model found")
173
+ return None, None, None, None
174
+
175
+ # 3️⃣ Load model with correct input_dim
176
+ model = GemmaTrainer(input_dim=input_dim).from_flashpack(local_model_path)
177
+ model.eval()
178
+
179
+ # 4️⃣ Build encoder
180
+ tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
181
+
182
+ # 5️⃣ Enhancement function
183
+ @torch.no_grad()
184
+ def enhance_fn(prompt, chat):
185
+ chat = chat or []
186
+ short_emb = encode_fn(prompt).to(device)
187
+ mapped = model(short_emb).cpu()
188
+ long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
189
+ chat.append({"role": "user", "content": prompt})
190
+ chat.append({"role": "assistant", "content": long_prompt})
191
+ return chat
192
+
193
+ return model, tokenizer, embed_model, enhance_fn
194
+
195
+ except Exception as e:
196
+ print(f"⚠️ Load failed: {e}")
197
+ print("⏬ Training a new FlashPack model locally...")
198
+ model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
199
+ push_flashpack_model_to_hf(model, hf_repo, log_fn=print)
200
+ return model, tokenizer, embed_model, None
201
+
202
+ def get_flashpack_model1(hf_repo="rahul7star/FlashPack"):
203
  try:
204
  print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
205
  model = GemmaTrainer.from_flashpack(hf_repo)