MGZON committed · verified
Commit b3daae1 · 1 Parent(s): 10775fb

Update app.py

Files changed (1): app.py (+51, -33)
app.py CHANGED
@@ -28,7 +28,8 @@ app = FastAPI(
 t5_tokenizer = None
 t5_model = None
 mistral = None
-models_loaded = False
+t5_loaded = False
+mistral_loaded = False
 
 # Root endpoint
 @app.get("/")
@@ -44,62 +45,73 @@ async def root():
 async def health_check():
     logger.info(f"Health check endpoint called at {time.time()}")
     return JSONResponse(
-        content={"status": "healthy" if models_loaded else "loading"},
+        content={"status": "healthy" if t5_loaded else "loading"},
         headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
     )
 
-# Async function to load models
-async def load_models():
-    global t5_tokenizer, t5_model, mistral, models_loaded
+# Async function to load T5 model
+async def load_t5_model():
+    global t5_tokenizer, t5_model, t5_loaded
     start_time = time.time()
-    logger.info(f"Starting model loading at {start_time}")
+    logger.info(f"Starting T5 model loading at {start_time}")
     try:
-        # Load T5 model from local cache
         T5_MODEL_PATH = os.path.join(CACHE_DIR, "models--MGZON--mgzon-flan-t5-base/snapshots")
         logger.info(f"Loading tokenizer for MGZON/mgzon-flan-t5-base from {T5_MODEL_PATH}")
         t5_tokenizer = AutoTokenizer.from_pretrained(
             T5_MODEL_PATH,
-            local_files_only=True
+            local_files_only=True,
+            torch_dtype="float16"  # Reduce memory usage
         )
         logger.info(f"Successfully loaded tokenizer for MGZON/mgzon-flan-t5-base in {time.time() - start_time} seconds")
         logger.info(f"Loading model for MGZON/mgzon-flan-t5-base from {T5_MODEL_PATH}")
         t5_model = AutoModelForSeq2SeqLM.from_pretrained(
             T5_MODEL_PATH,
-            local_files_only=True
+            local_files_only=True,
+            torch_dtype="float16"  # Reduce memory usage
        )
         logger.info(f"Successfully loaded model for MGZON/mgzon-flan-t5-base in {time.time() - start_time} seconds")
+        t5_loaded = True
+    except Exception as e:
+        logger.error(f"Failed to load T5 model: {str(e)}", exc_info=True)
+        t5_loaded = False
+        raise RuntimeError(f"Failed to load T5 model: {str(e)}")
+    finally:
+        end_time = time.time()
+        logger.info(f"T5 model loading completed in {end_time - start_time} seconds")
 
-        # Load Mistral GGUF model
+# Async function to load Mistral model
+async def load_mistral_model():
+    global mistral, mistral_loaded
+    start_time = time.time()
+    logger.info(f"Starting Mistral model loading at {start_time}")
+    try:
         gguf_path = os.path.abspath("models/mistral-7b-instruct-v0.1.Q2_K.gguf")
         if not os.path.exists(gguf_path):
             logger.error(f"Mistral GGUF file not found at {gguf_path}")
-            raise RuntimeError(
-                f"Mistral GGUF file not found at {gguf_path}. "
-                "Make sure setup.sh was executed during the build."
-            )
-
+            raise RuntimeError(f"Mistral GGUF file not found at {gguf_path}")
         logger.info(f"Loading Mistral model from {gguf_path}")
         mistral = Llama(
             model_path=gguf_path,
-            n_ctx=512,  # reduced n_ctx further to cut memory use
-            n_threads=1,  # just one thread
-            n_batch=128,  # reduced n_batch further
+            n_ctx=512,
+            n_threads=1,
+            n_batch=128,
             verbose=True
         )
         logger.info(f"Successfully loaded Mistral model from {gguf_path} in {time.time() - start_time} seconds")
-        models_loaded = True
+        mistral_loaded = True
     except Exception as e:
-        logger.error(f"Failed to load models: {str(e)}")
-        raise RuntimeError(f"Failed to load models: {str(e)}")
+        logger.error(f"Failed to load Mistral model: {str(e)}", exc_info=True)
+        mistral_loaded = False
+        raise RuntimeError(f"Failed to load Mistral model: {str(e)}")
     finally:
         end_time = time.time()
-        logger.info(f"Model loading completed in {end_time - start_time} seconds")
+        logger.info(f"Mistral model loading completed in {end_time - start_time} seconds")
 
-# Run model loading in the background
+# Run T5 model loading in the background
 @app.on_event("startup")
 async def startup_event():
     logger.info(f"Startup event triggered at {time.time()}")
-    asyncio.create_task(load_models())
+    asyncio.create_task(load_t5_model())  # Load only T5 at startup
 
 # Define request schema
 class AskRequest(BaseModel):
@@ -110,9 +122,9 @@ class AskRequest(BaseModel):
 @app.post("/ask")
 async def ask(req: AskRequest):
     logger.info(f"Received ask request: {req.question} at {time.time()}")
-    if not models_loaded:
-        logger.error("Models not loaded yet")
-        raise HTTPException(status_code=503, detail="Models are still loading, please try again later")
+    if not t5_loaded:
+        logger.error("T5 model not loaded yet")
+        raise HTTPException(status_code=503, detail="T5 model is still loading, please try again later")
 
     q = req.question.strip()
     if not q:
@@ -121,14 +133,20 @@ async def ask(req: AskRequest):
 
     try:
         if any(tok in q.lower() for tok in ["mgzon", "flan", "t5"]):
-            # T5 model
+            # Use T5 model
             logger.info("Using MGZON-FLAN-T5 model")
            inputs = t5_tokenizer(q, return_tensors="pt", truncation=True, max_length=256)
             out_ids = t5_model.generate(**inputs, max_length=req.max_new_tokens)
             answer = t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)
             model_name = "MGZON-FLAN-T5"
         else:
-            # Mistral model
+            # Load Mistral model if not loaded
+            if not mistral_loaded:
+                logger.info("Mistral model not loaded, loading now...")
+                await load_mistral_model()
+                if not mistral_loaded:
+                    raise HTTPException(status_code=503, detail="Failed to load Mistral model")
+            # Use Mistral model
             logger.info("Using Mistral-7B-GGUF model")
             out = mistral(prompt=q, max_tokens=req.max_new_tokens, temperature=0.7)
             answer = out["choices"][0]["text"].strip()
@@ -136,7 +154,7 @@ async def ask(req: AskRequest):
         logger.info(f"Response generated by {model_name}: {answer}")
         return {"model": model_name, "response": answer}
     except Exception as e:
-        logger.error(f"Error processing request: {str(e)}")
+        logger.error(f"Error processing request: {str(e)}", exc_info=True)
         raise HTTPException(status_code=500, detail=f"Error while processing the request: {str(e)}")
 
 # Run the app
@@ -147,7 +165,7 @@ if __name__ == "__main__":
         port=8080,
         log_level="info",
         workers=1,
-        timeout_keep_alive=15,  # reduce keep-alive time
-        limit_concurrency=5,  # limit concurrent connections further
-        limit_max_requests=50  # limit max requests further
+        timeout_keep_alive=15,
+        limit_concurrency=5,
+        limit_max_requests=50
     )
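
One thing worth flagging in the new loader: torch_dtype="float16" is passed to both from_pretrained calls. For the model this is meaningful (recent transformers releases accept the string form and load the weights in half precision, roughly halving their memory), but a tokenizer has no weight tensors, so the argument does nothing there and can be dropped. A minimal sketch of the likely intent; the path is a placeholder, and whether float16 inference is actually fast on the Space's CPU is an assumption to verify:

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

T5_MODEL_PATH = "/path/to/models--MGZON--mgzon-flan-t5-base/snapshots"  # placeholder; see the diff

# A tokenizer carries no tensors, so no dtype argument is needed here.
t5_tokenizer = AutoTokenizer.from_pretrained(T5_MODEL_PATH, local_files_only=True)

# The dtype matters only for the model weights: torch.float16 halves their size,
# though half-precision inference on CPU can be slow or unsupported for some ops.
t5_model = AutoModelForSeq2SeqLM.from_pretrained(
    T5_MODEL_PATH,
    local_files_only=True,
    torch_dtype=torch.float16,
)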
 
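
A subtlety in the new /ask flow: the raise HTTPException(status_code=503, ...) for a failed Mistral load sits inside the handler's try: block, so the blanket except Exception catches it and converts it into a 500. For the 503 to actually reach the client, HTTPException needs its own re-raise ahead of the generic handler, roughly as below (dispatch is a hypothetical helper standing in for the diff's branch logic):

from fastapi import HTTPException

@app.post("/ask")
async def ask(req: AskRequest):
    try:
        answer, model_name = await dispatch(req)  # hypothetical helper for the model branching
        return {"model": model_name, "response": answer}
    except HTTPException:
        raise  # preserve the intended status code (e.g. the 503 for a failed lazy load)
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error while processing the request: {str(e)}")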
 
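Finally, @app.on_event("startup") still works but is deprecated in current FastAPI releases in favor of a lifespan handler. An equivalent sketch of the same startup hook, assuming the load_t5_model from this diff:

import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: kick off the background T5 load, as in the diff.
    asyncio.create_task(load_t5_model())
    yield
    # Shutdown: nothing to clean up yet.

app = FastAPI(lifespan=lifespan)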