import os
import logging
import time
import asyncio
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from llama_cpp import Llama
import uvicorn

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set up cache directory
CACHE_DIR = "/app/.cache/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR

# Create the FastAPI app
app = FastAPI(
    title="MGZON Smart Assistant",
    description="Combines the fine-tuned T5 model with Mistral-7B (GGUF) inside the Space"
)

# Initialize model variables
t5_tokenizer = None
t5_model = None
mistral = None
t5_loaded = False
mistral_loaded = False

# Root endpoint
@app.get("/")
async def root():
    logger.info(f"Root endpoint called at {time.time()}")
    return JSONResponse(
        content={"message": "MGZON Smart Assistant is running"},
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
    )

# Health check endpoint
@app.get("/health")
async def health_check():
    logger.info(f"Health check endpoint called at {time.time()}")
    return JSONResponse(
        content={"status": "healthy" if t5_loaded else "loading"},
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
    )

# Async function to load the T5 model
async def load_t5_model():
    global t5_tokenizer, t5_model, t5_loaded
    start_time = time.time()
    logger.info(f"Starting T5 model loading at {start_time}")
    try:
        # Load by repo id from the local cache. Pointing from_pretrained at the bare
        # ".../snapshots" directory fails because config.json lives one level deeper,
        # inside the snapshot-hash subdirectory.
        model_id = "MGZON/mgzon-flan-t5-base"
        logger.info(f"Loading tokenizer for {model_id} from {CACHE_DIR}")
        t5_tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            cache_dir=CACHE_DIR,
            local_files_only=True
        )
        logger.info(f"Successfully loaded tokenizer for {model_id} in {time.time() - start_time} seconds")
        logger.info(f"Loading model for {model_id} from {CACHE_DIR}")
        t5_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            cache_dir=CACHE_DIR,
            local_files_only=True,
            torch_dtype="float16"  # Reduce memory usage
        )
        logger.info(f"Successfully loaded model for {model_id} in {time.time() - start_time} seconds")
        t5_loaded = True
    except Exception as e:
        logger.error(f"Failed to load T5 model: {str(e)}", exc_info=True)
        t5_loaded = False
        raise RuntimeError(f"Failed to load T5 model: {str(e)}")
    finally:
        end_time = time.time()
        logger.info(f"T5 model loading completed in {end_time - start_time} seconds")

# Async function to load the Mistral model
async def load_mistral_model():
    global mistral, mistral_loaded
    start_time = time.time()
    logger.info(f"Starting Mistral model loading at {start_time}")
    try:
        gguf_path = os.path.abspath("models/mistral-7b-instruct-v0.1.Q2_K.gguf")
        if not os.path.exists(gguf_path):
            logger.error(f"Mistral GGUF file not found at {gguf_path}")
            raise RuntimeError(f"Mistral GGUF file not found at {gguf_path}")
        logger.info(f"Loading Mistral model from {gguf_path}")
        mistral = Llama(
            model_path=gguf_path,
            n_ctx=512,
            n_threads=1,
            n_batch=128,
            verbose=True
        )
        logger.info(f"Successfully loaded Mistral model from {gguf_path} in {time.time() - start_time} seconds")
        mistral_loaded = True
    except Exception as e:
        logger.error(f"Failed to load Mistral model: {str(e)}", exc_info=True)
        mistral_loaded = False
        raise RuntimeError(f"Failed to load Mistral model: {str(e)}")
    finally:
        end_time = time.time()
        logger.info(f"Mistral model loading completed in {end_time - start_time} seconds")

# Run T5 model loading in the background
@app.on_event("startup")
async def startup_event():
    logger.info(f"Startup event triggered at {time.time()}")
    asyncio.create_task(load_t5_model())  # Load only T5 at startup

# Define request schema
class AskRequest(BaseModel):
    question: str
    max_new_tokens: int = 150

# Endpoint: /ask
@app.post("/ask")
async def ask(req: AskRequest):
    logger.info(f"Received ask request: {req.question} at {time.time()}")
    if not t5_loaded:
        logger.error("T5 model not loaded yet")
        raise HTTPException(status_code=503, detail="T5 model is still loading, please try again later")
    q = req.question.strip()
    if not q:
        logger.error("Empty question received")
        raise HTTPException(status_code=400, detail="Empty question")
    try:
        if any(tok in q.lower() for tok in ["mgzon", "flan", "t5"]):
            # Use the T5 model
            logger.info("Using MGZON-FLAN-T5 model")
            inputs = t5_tokenizer(q, return_tensors="pt", truncation=True, max_length=256)
            out_ids = t5_model.generate(**inputs, max_new_tokens=req.max_new_tokens)
            answer = t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)
            model_name = "MGZON-FLAN-T5"
        else:
            # Lazily load the Mistral model on first use
            if not mistral_loaded:
                logger.info("Mistral model not loaded, loading now...")
                await load_mistral_model()
                if not mistral_loaded:
                    raise HTTPException(status_code=503, detail="Failed to load Mistral model")
            # Use the Mistral model
            logger.info("Using Mistral-7B-GGUF model")
            out = mistral(prompt=q, max_tokens=req.max_new_tokens, temperature=0.7)
            answer = out["choices"][0]["text"].strip()
            model_name = "Mistral-7B-GGUF"
        logger.info(f"Response generated by {model_name}: {answer}")
        return {"model": model_name, "response": answer}
    except HTTPException:
        # Re-raise HTTP errors unchanged instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error while processing the request: {str(e)}")

# Run the app
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8080,
        log_level="info",
        workers=1,
        timeout_keep_alive=15,
        limit_concurrency=5,
        limit_max_requests=50
    )