api-mg / app.py
ibrahimlasfar's picture
update for cache
431e7f9
raw
history blame
3.45 kB
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from llama_cpp import Llama
# -------------------------------------------------
# إعداد مسار الـ cache
# -------------------------------------------------
CACHE_DIR = "/app/.cache/huggingface" # مسار موحد لـ Hugging Face Spaces
os.makedirs(CACHE_DIR, exist_ok=True)
# تأكد من أن المكتبتين تقرأ المتغيّرات البيئية
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR
# -------------------------------------------------
# إنشاء التطبيق
# -------------------------------------------------
app = FastAPI(
title="MGZON Smart Assistant",
description="دمج نموذج T5 المدرب مع Mistral‑7B (GGUF) داخل Space"
)
# -------------------------------------------------
# 1️⃣ تحميل نموذج T5 المدرب من Hub
# -------------------------------------------------
T5_REPO = "MGZON/mgzon-flan-t5-base"
try:
t5_tokenizer = AutoTokenizer.from_pretrained(T5_REPO, cache_dir=CACHE_DIR)
t5_model = AutoModelForSeq2SeqLM.from_pretrained(T5_REPO, cache_dir=CACHE_DIR)
except Exception as e:
raise RuntimeError(f"فشل تحميل نموذج T5 من {T5_REPO}: {str(e)}")
# -------------------------------------------------
# 2️⃣ تحميل ملف Mistral .gguf
# -------------------------------------------------
gguf_path = os.path.abspath("models/mistral-7b-instruct-v0.1.Q4_K_M.gguf")
if not os.path.exists(gguf_path):
raise RuntimeError(
f"ملف Mistral .gguf غير موجود في {gguf_path}. "
"تأكد من أن ملف setup.sh تم تنفيذه أثناء الـ build."
)
try:
mistral = Llama(
model_path=gguf_path,
n_ctx=2048,
n_threads=8,
# إذا كان لديك GPU، يمكنك إضافة: n_gpu_layers=35
)
except Exception as e:
raise RuntimeError(f"فشل تحميل نموذج Mistral من {gguf_path}: {str(e)}")
# -------------------------------------------------
# تعريف شكل الطلب (JSON)
# -------------------------------------------------
class AskRequest(BaseModel):
question: str
max_new_tokens: int = 150
# -------------------------------------------------
# نقطة النهاية /ask
# -------------------------------------------------
@app.post("/ask")
def ask(req: AskRequest):
q = req.question.strip()
if not q:
raise HTTPException(status_code=400, detail="Empty question")
# منطق اختيار النموذج
try:
if any(tok in q.lower() for tok in ["mgzon", "flan", "t5"]):
# نموذج T5
inputs = t5_tokenizer(q, return_tensors="pt", truncation=True, max_length=256)
out_ids = t5_model.generate(**inputs, max_length=req.max_new_tokens)
answer = t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)
model_name = "MGZON-FLAN-T5"
else:
# نموذج Mistral
out = mistral(prompt=q, max_tokens=req.max_new_tokens)
answer = out["choices"][0]["text"].strip()
model_name = "Mistral-7B-GGUF"
return {"model": model_name, "response": answer}
except Exception as e:
raise HTTPException(status_code=500, detail=f"خطأ أثناء معالجة الطلب: {str(e)}")