Update app_flash.py
app_flash.py (+47 -0) CHANGED
@@ -153,6 +153,53 @@ def train_flashpack_model(
 # 5️⃣ Load or train model
 # ============================================================
 def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
+    input_dim = 1536  # must match the input_dim used during training
+    try:
+        print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
+
+        # 1️⃣ Try local model first
+        local_model_path = "model.flashpack"
+        if os.path.exists(local_model_path):
+            print("✅ Loading local model")
+        else:
+            # 2️⃣ Try Hugging Face
+            files = list_repo_files(hf_repo)
+            if "model.flashpack" in files:
+                print("✅ Downloading model from HF")
+                from huggingface_hub import hf_hub_download
+                local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
+            else:
+                print("🚫 No pretrained model found")
+                return None, None, None, None
+
+        # 3️⃣ Load model with correct input_dim
+        model = GemmaTrainer(input_dim=input_dim).from_flashpack(local_model_path)
+        model.eval()
+
+        # 4️⃣ Build encoder
+        tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
+
+        # 5️⃣ Enhancement function
+        @torch.no_grad()
+        def enhance_fn(prompt, chat):
+            chat = chat or []
+            short_emb = encode_fn(prompt).to(device)
+            mapped = model(short_emb).cpu()
+            long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
+            chat.append({"role": "user", "content": prompt})
+            chat.append({"role": "assistant", "content": long_prompt})
+            return chat
+
+        return model, tokenizer, embed_model, enhance_fn
+
+    except Exception as e:
+        print(f"⚠️ Load failed: {e}")
+        print("⏬ Training a new FlashPack model locally...")
+        model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
+        push_flashpack_model_to_hf(model, hf_repo, log_fn=print)
+        return model, tokenizer, embed_model, None
+
+def get_flashpack_model1(hf_repo="rahul7star/FlashPack"):
     try:
         print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
         model = GemmaTrainer.from_flashpack(hf_repo)
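For reference, a minimal usage sketch of the new loader (caller-side code, not part of this commit; the call site and prompt text are assumptions): it unpacks the four return values and only calls enhance_fn when a pretrained model was actually loaded, since both the training fallback and the "no model found" path return None in that slot.

# Hypothetical caller — not in the diff; illustrates the return contract of get_flashpack_model().
model, tokenizer, embed_model, enhance_fn = get_flashpack_model("rahul7star/FlashPack")

if enhance_fn is not None:
    chat = enhance_fn("a cat on a windowsill", chat=[])  # appends a user turn and an assistant turn
    print(chat[-1]["content"])                           # the enhanced prompt
elif model is None:
    print("No pretrained model found locally or on the Hub.")
else:
    print("Model was freshly trained and pushed; enhance_fn is not returned on this path.")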