|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
from utils import compare_embeddings |
|
|
|
|
|
from models.esm_2_650m import get_embedding as get_embedding_esm_2_650m |
|
|
|
|
|
app = FastAPI() |
|
|
|
|
|
|
|
|
class CompareRequest(BaseModel): |
|
|
sequence_1: str |
|
|
sequence_2: str |
|
|
model: str = "esm_2_650m" |
|
|
|
|
|
|
|
|
model_mapping = {"esm_2_650m": get_embedding_esm_2_650m} |
|
|
|
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
def root(): |
|
|
return { |
|
|
"message": "API is running. Use POST /compare to compare protein sequences." |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@app.post("/compare") |
|
|
def compare(request: CompareRequest): |
|
|
model = request.model |
|
|
|
|
|
if model not in model_mapping: |
|
|
return {"error": "Model not supported"} |
|
|
|
|
|
emb1 = model_mapping[model](request.sequence_1) |
|
|
emb2 = model_mapping[model](request.sequence_2) |
|
|
|
|
|
similarity, classification = compare_embeddings(emb1, emb2) |
|
|
|
|
|
return { |
|
|
"cosine_similarity": float(similarity), |
|
|
"classification": classification, |
|
|
"model": model, |
|
|
} |
|
|
|