Update app.py
app.py CHANGED
@@ -28,7 +28,8 @@ app = FastAPI(
 t5_tokenizer = None
 t5_model = None
 mistral = None
-
+t5_loaded = False
+mistral_loaded = False
 
 # Root endpoint
 @app.get("/")
@@ -44,62 +45,73 @@ async def root():
 async def health_check():
     logger.info(f"Health check endpoint called at {time.time()}")
     return JSONResponse(
-        content={"status": "healthy" if
+        content={"status": "healthy" if t5_loaded else "loading"},
         headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
     )
 
-# Async function to load
-async def
-    global t5_tokenizer, t5_model,
+# Async function to load T5 model
+async def load_t5_model():
+    global t5_tokenizer, t5_model, t5_loaded
     start_time = time.time()
-    logger.info(f"Starting model loading at {start_time}")
+    logger.info(f"Starting T5 model loading at {start_time}")
     try:
-        # Load T5 model from local cache
         T5_MODEL_PATH = os.path.join(CACHE_DIR, "models--MGZON--mgzon-flan-t5-base/snapshots")
         logger.info(f"Loading tokenizer for MGZON/mgzon-flan-t5-base from {T5_MODEL_PATH}")
         t5_tokenizer = AutoTokenizer.from_pretrained(
             T5_MODEL_PATH,
-            local_files_only=True
+            local_files_only=True,
+            torch_dtype="float16"  # Reduce memory usage
         )
         logger.info(f"Successfully loaded tokenizer for MGZON/mgzon-flan-t5-base in {time.time() - start_time} seconds")
         logger.info(f"Loading model for MGZON/mgzon-flan-t5-base from {T5_MODEL_PATH}")
         t5_model = AutoModelForSeq2SeqLM.from_pretrained(
             T5_MODEL_PATH,
-            local_files_only=True
+            local_files_only=True,
+            torch_dtype="float16"  # Reduce memory usage
         )
         logger.info(f"Successfully loaded model for MGZON/mgzon-flan-t5-base in {time.time() - start_time} seconds")
+        t5_loaded = True
+    except Exception as e:
+        logger.error(f"Failed to load T5 model: {str(e)}", exc_info=True)
+        t5_loaded = False
+        raise RuntimeError(f"Failed to load T5 model: {str(e)}")
+    finally:
+        end_time = time.time()
+        logger.info(f"T5 model loading completed in {end_time - start_time} seconds")
 
-
+# Async function to load Mistral model
+async def load_mistral_model():
+    global mistral, mistral_loaded
+    start_time = time.time()
+    logger.info(f"Starting Mistral model loading at {start_time}")
+    try:
         gguf_path = os.path.abspath("models/mistral-7b-instruct-v0.1.Q2_K.gguf")
         if not os.path.exists(gguf_path):
             logger.error(f"Mistral GGUF file not found at {gguf_path}")
-            raise RuntimeError(
-                f"Mistral GGUF file not found at {gguf_path}. "
-                "تأكد من أن ملف setup.sh تم تنفيذه أثناء الـ build."
-            )
-
+            raise RuntimeError(f"Mistral GGUF file not found at {gguf_path}")
         logger.info(f"Loading Mistral model from {gguf_path}")
         mistral = Llama(
             model_path=gguf_path,
-            n_ctx=512,
-            n_threads=1,
-            n_batch=128,
+            n_ctx=512,
+            n_threads=1,
+            n_batch=128,
             verbose=True
         )
         logger.info(f"Successfully loaded Mistral model from {gguf_path} in {time.time() - start_time} seconds")
-
+        mistral_loaded = True
     except Exception as e:
-        logger.error(f"Failed to load
-
+        logger.error(f"Failed to load Mistral model: {str(e)}", exc_info=True)
+        mistral_loaded = False
+        raise RuntimeError(f"Failed to load Mistral model: {str(e)}")
     finally:
         end_time = time.time()
-        logger.info(f"
+        logger.info(f"Mistral model loading completed in {end_time - start_time} seconds")
 
-# Run model loading in the background
+# Run T5 model loading in the background
 @app.on_event("startup")
 async def startup_event():
     logger.info(f"Startup event triggered at {time.time()}")
-    asyncio.create_task(
+    asyncio.create_task(load_t5_model())  # Load only T5 at startup
 
 # Define request schema
 class AskRequest(BaseModel):
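A note on the two loaders in this hunk: `AutoTokenizer.from_pretrained`, `AutoModelForSeq2SeqLM.from_pretrained`, and the `Llama(...)` constructor are all synchronous, so although `load_t5_model()` is declared `async`, scheduling it with `asyncio.create_task()` still blocks the event loop (health checks included) while the weights load. As an aside, `torch_dtype` only affects the model; the tokenizer holds no weights, so it should be ignored in the tokenizer call. Below is a minimal sketch of one way to keep the loop responsive, assuming Python 3.9+ for `asyncio.to_thread`; the `_load_t5_blocking` helper and the `path` parameter are illustrative, not part of the commit:

```python
import asyncio

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

t5_tokenizer = None
t5_model = None
t5_loaded = False

def _load_t5_blocking(path: str):
    # Plain blocking Hugging Face calls; safe to run in a worker thread.
    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(path, local_files_only=True)
    return tokenizer, model

async def load_t5_model(path: str):
    # to_thread() runs the blocking load off the event loop, so the health
    # endpoint keeps answering "loading" instead of stalling until the load ends.
    global t5_tokenizer, t5_model, t5_loaded
    t5_tokenizer, t5_model = await asyncio.to_thread(_load_t5_blocking, path)
    t5_loaded = True
```

The same consideration applies to the GGUF load inside `load_mistral_model()`.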
@@ -110,9 +122,9 @@ class AskRequest(BaseModel):
 @app.post("/ask")
 async def ask(req: AskRequest):
     logger.info(f"Received ask request: {req.question} at {time.time()}")
-    if not
-        logger.error("
-        raise HTTPException(status_code=503, detail="
+    if not t5_loaded:
+        logger.error("T5 model not loaded yet")
+        raise HTTPException(status_code=503, detail="T5 model is still loading, please try again later")
 
     q = req.question.strip()
     if not q:
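The guard above returns a bare 503 while T5 is still loading. Since clients are expected to retry, the same guard could also advertise a retry interval via the standard `Retry-After` header, which FastAPI's `HTTPException` supports through its `headers` argument. A sketch; the 10-second value is an arbitrary illustration:

```python
from fastapi import HTTPException

if not t5_loaded:
    logger.error("T5 model not loaded yet")
    raise HTTPException(
        status_code=503,
        detail="T5 model is still loading, please try again later",
        headers={"Retry-After": "10"},  # seconds; illustrative value
    )
```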
@@ -121,14 +133,20 @@ async def ask(req: AskRequest):
 
     try:
         if any(tok in q.lower() for tok in ["mgzon", "flan", "t5"]):
-            #
+            # Use T5 model
             logger.info("Using MGZON-FLAN-T5 model")
             inputs = t5_tokenizer(q, return_tensors="pt", truncation=True, max_length=256)
             out_ids = t5_model.generate(**inputs, max_length=req.max_new_tokens)
             answer = t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)
             model_name = "MGZON-FLAN-T5"
         else:
-            #
+            # Load Mistral model if not loaded
+            if not mistral_loaded:
+                logger.info("Mistral model not loaded, loading now...")
+                await load_mistral_model()
+                if not mistral_loaded:
+                    raise HTTPException(status_code=503, detail="Failed to load Mistral model")
+            # Use Mistral model
             logger.info("Using Mistral-7B-GGUF model")
             out = mistral(prompt=q, max_tokens=req.max_new_tokens, temperature=0.7)
             answer = out["choices"][0]["text"].strip()
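One caveat about the lazy load in the `else` branch: `load_mistral_model()` raises on failure, so the inner `mistral_loaded` re-check is only a safety net, and nothing stops two concurrent non-T5 requests from both observing `mistral_loaded == False` and starting duplicate GGUF loads. A sketch of one way to serialize first use, reusing `load_mistral_model()` and the `mistral_loaded` flag from the diff; the lock and helper names are illustrative:

```python
import asyncio

_mistral_lock = asyncio.Lock()  # hypothetical module-level lock

async def ensure_mistral_loaded():
    # Fast path: flag already set, skip the lock entirely.
    if mistral_loaded:
        return
    async with _mistral_lock:
        # Re-check inside the lock: another request may have finished
        # the load while this one was waiting.
        if not mistral_loaded:
            await load_mistral_model()
```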
@@ -136,7 +154,7 @@ async def ask(req: AskRequest):
         logger.info(f"Response generated by {model_name}: {answer}")
         return {"model": model_name, "response": answer}
     except Exception as e:
-        logger.error(f"Error processing request: {str(e)}")
+        logger.error(f"Error processing request: {str(e)}", exc_info=True)
         raise HTTPException(status_code=500, detail=f"خطأ أثناء معالجة الطلب: {str(e)}")
 
 # Run the app
@@ -147,7 +165,7 @@ if __name__ == "__main__":
         port=8080,
         log_level="info",
         workers=1,
-        timeout_keep_alive=15,
-        limit_concurrency=5,
-        limit_max_requests=50
+        timeout_keep_alive=15,
+        limit_concurrency=5,
+        limit_max_requests=50
     )
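Taken together, the commit splits startup into two phases: T5 loads in the background at boot and gates `/ask` until ready, while Mistral loads lazily on the first non-T5 request. A quick smoke test of that behavior, assuming the health route is mounted at `/health` (its decorator falls outside these hunks) and the server runs locally on the configured port 8080; `requests` is just a convenient client choice:

```python
import time

import requests

BASE = "http://127.0.0.1:8080"

# Poll until the background T5 load flips the status to "healthy".
while requests.get(f"{BASE}/health").json().get("status") != "healthy":
    time.sleep(2)

# A question mentioning "t5" routes to MGZON-FLAN-T5.
r = requests.post(f"{BASE}/ask", json={"question": "What is mgzon-flan-t5?", "max_new_tokens": 64})
print(r.json())

# Any other question triggers the lazy Mistral load; expect the first
# such call to be slow while the GGUF file is read from disk.
r = requests.post(f"{BASE}/ask", json={"question": "Hello there", "max_new_tokens": 64})
print(r.json())
```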