import os import torch from fastapi import FastAPI from pydantic import BaseModel, Field from comet import load_from_checkpoint from huggingface_hub import snapshot_download, HfApi # ========================================================== # 🚀 Configuração da API # ========================================================== app = FastAPI( title="COMETKiwi-DA-XL API", version="1.0.0", description="API para avaliação de traduções usando Unbabel/wmt23-cometkiwi-da-xl " "(modelo de Quality Estimation sem referência)." ) MODEL_NAME = "Unbabel/wmt23-cometkiwi-da-xl" HF_TOKEN = os.environ.get("HF_TOKEN") SPACE_REPO_ID = os.environ.get("SPACE_REPO_ID", "nairut/cometkiwi-xl") # Diretório local de cache (gravável no Space) MODEL_DIR = os.path.join(os.path.dirname(__file__), "model") MODEL_CKPT = os.path.join(MODEL_DIR, "checkpoints", "model.ckpt") # ========================================================== # ⚙️ Função auxiliar: garante que o modelo está disponível # ========================================================== def ensure_model_persisted_once(): """ Se o modelo ainda não estiver presente em ./model, faz download e tenta persistir a pasta dentro do Space. """ # Se o modelo anterior (XCOMET) existir, apaga if os.path.exists(MODEL_DIR): # Verifica se é outro modelo if not os.path.exists(MODEL_CKPT): print("🧹 Limpando cache antigo de modelo...") import shutil shutil.rmtree(MODEL_DIR, ignore_errors=True) # Se o checkpoint já existir, pula o download if os.path.exists(MODEL_CKPT): print(f"✅ Modelo já encontrado em {MODEL_CKPT}. Pulando download.") return print("🔽 Baixando modelo Unbabel/wmt23-cometkiwi-da-xl para ./model ...") snapshot_download( repo_id=MODEL_NAME, token=HF_TOKEN, local_dir=MODEL_DIR, local_dir_use_symlinks=False ) assert os.path.exists(MODEL_CKPT), f"Checkpoint não encontrado: {MODEL_CKPT}" # Tenta persistir o modelo no próprio Space try: print("⬆️ Enviando pasta 'model/' para o repositório do Space ...") api = HfApi(token=HF_TOKEN) api.upload_folder( repo_id=SPACE_REPO_ID, repo_type="space", folder_path=MODEL_DIR, path_in_repo="model", commit_message="Persistência automática do modelo COMETKiwi-DA-XL" ) print("✅ Modelo persistido com sucesso no Space.") except Exception as e: print(f"⚠️ Falha ao persistir modelo no Space: {e}") print(" Prosseguindo com o modelo local.") # ========================================================== # 📦 Inicialização do modelo # ========================================================== ensure_model_persisted_once() print(f"📂 Carregando modelo de {MODEL_CKPT} ...") model = load_from_checkpoint(MODEL_CKPT) print("✅ Modelo COMETKiwi-DA-XL carregado com sucesso!") USE_GPU = 1 if torch.cuda.is_available() else 0 print(f"⚙️ GPU detectada: {'sim' if USE_GPU else 'não'}") # ========================================================== # 🧠 Estrutura dos dados de entrada # ========================================================== class TranslationPair(BaseModel): source: str = Field(alias="source", description="Texto original (source)") target: str = Field(alias="target", description="Tradução automática (machine translation)") class Config: allow_population_by_field_name = True # ========================================================== # 🔧 Função utilitária # ========================================================== def prepare_data(pairs: list[TranslationPair]): """ Converte lista de TranslationPair no formato esperado pelo COMET: [{"src": ..., "mt": ..., "ref": ...}, ...] """ data = [] for p in pairs: src = str(p.source).strip() mt = str(p.target).strip() item = {"src": src, "mt": mt} data.append(item) return data # ========================================================== # 🌐 Endpoints # ========================================================== @app.get("/") def root(): return { "message": "🚀 COMETKiwi-DA-XL API ativa e pronta para uso!", "gpu_enabled": bool(USE_GPU), "available_endpoints": ["/score", "/score_batch"] } @app.post("/score") def score_single(pair: TranslationPair): """ Avalia um único par de tradução (source → target) com COMETKiwi-DA-XL. """ try: data = [{"src": pair.source.strip(), "mt": pair.target.strip()}] output = model.predict(data, batch_size=8, gpus=USE_GPU) return { "system_score": getattr(output, "system_score", None), "segment_scores": getattr(output, "scores", None), "metadata": getattr(output, "metadata", None) } except Exception as e: print(f"❌ Erro em /score: {e}") return {"error": str(e)} @app.post("/score_batch") def score_batch(pairs: list[TranslationPair]): """ Avalia múltiplos pares de tradução em lote. """ try: data = prepare_data(pairs) output = model.predict(data, batch_size=8, gpus=USE_GPU) return { "system_score": getattr(output, "system_score", None), "segment_scores": getattr(output, "scores", None), } except Exception as e: print(f"❌ Erro em /score_batch: {e}") return {"error": str(e)} # ========================================================== # ▶️ Execução local (para debug) # ========================================================== if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)