# Source: Hugging Face Space page (Space status shown: "Runtime error"); file size 2,441 bytes.
# Commits touching this file (from blame view): 1f9b706, 10afc48, 4cfd29e, bf46f16, 94ddf63.
import gradio as gr
from transformers import AutoProcessor, AutoModel, AutoTokenizer, pipeline
from PIL import Image
import torch
# ---------------------------
# Load encoder (MedSigLIP)
# ---------------------------
# One checkpoint ID shared by the processor and the model so the two always match.
ENCODER_ID = "fokan/medsiglip-448-int8"

# Preprocessor for the image + text inputs that analyze_xray() feeds the encoder.
encoder_processor = AutoProcessor.from_pretrained(ENCODER_ID)

# Inference only: eval() disables training-time behaviour such as dropout.
# (Module.eval() returns the module itself, so this matches chaining it.)
encoder_model = AutoModel.from_pretrained(ENCODER_ID)
encoder_model.eval()
# ---------------------------
# Load decoder (MedGemma)
# ---------------------------
DECODER_ID = "fokan/medgemma-4b-it-int8"

# Half precision when a CUDA device is visible, full precision otherwise.
_decoder_dtype = torch.float32
if torch.cuda.is_available():
    _decoder_dtype = torch.float16

# NOTE(review): no `device`/`device_map` is passed here, so placement is left to
# transformers' defaults. If the model stays on CPU while float16 was selected,
# generation may be slow or unsupported -- confirm on the target hardware.
decoder = pipeline(
    "text-generation",
    model=DECODER_ID,
    torch_dtype=_decoder_dtype,
)
# ---------------------------
# Core Function
# ---------------------------
def _pool_embedding(outputs):
    """Reduce encoder outputs to a single 1-D embedding for the first image.

    Preference order:
      1. ``outputs.image_embeds[0]`` when that field exists and is not None.
      2. ``outputs.last_hidden_state.mean(dim=1)[0]`` (presumably averaging the
         token axis of a (batch, seq, dim) tensor -- TODO confirm).
      3. The first value of ``outputs.values()`` averaged over dim 1.
    """
    # hasattr() is not enough: an output object may define the attribute yet
    # hold None in it, and indexing None would raise a TypeError.
    image_embeds = getattr(outputs, "image_embeds", None)
    if image_embeds is not None:
        return image_embeds[0]
    last_hidden = getattr(outputs, "last_hidden_state", None)
    if last_hidden is not None:
        return last_hidden.mean(dim=1)[0]
    # NOTE(review): assumes the first value is at least 3-D (batch, seq, dim);
    # a 2-D tensor here would collapse to a scalar -- confirm if this path is hit.
    return list(outputs.values())[0].mean(dim=1)[0]


@torch.no_grad()
def analyze_xray(image):
    """Encode an X-ray with the MedSigLIP encoder and ask MedGemma for a report.

    Args:
        image: PIL image from the Gradio upload widget, or None when the user
            submits without uploading anything.

    Returns:
        A Markdown string with a 5-value embedding preview and the generated
        report text, or a short warning message when no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty upload; answer politely instead of
        # letting the processor raise deep inside transformers.
        return "⚠️ Please upload an X-ray image first."

    # Step 1: Encode the image into an L2-normalised embedding.
    if getattr(image, "mode", "RGB") != "RGB":
        # Grayscale ("L") / RGBA uploads are converted to plain RGB, the usual
        # input format for SigLIP-style processors (assumption -- confirm).
        image = image.convert("RGB")
    inputs = encoder_processor(images=image, text=["chest x-ray"], return_tensors="pt", padding=True)
    outputs = encoder_model(**inputs)
    # normalize() clamps the norm with a small epsilon, so an all-zero vector
    # does not turn into NaNs the way a bare division by norm() would.
    embedding = torch.nn.functional.normalize(_pool_embedding(outputs), dim=-1)

    # Step 2: Generate a short diagnostic report.
    # NOTE: the decoder only receives this text prompt -- it has no access to
    # the image itself, just a textual list of embedding values -- so the
    # report is illustrative output, not a grounded diagnosis.
    prompt = (
        "You are a radiologist. Analyze this chest X-ray embedding vector and describe any possible findings, "
        "anomalies, or impressions as a short professional report.\n"
        f"<embedding>{embedding[:256].tolist()}</embedding>"
    )
    # do_sample=True makes sampling explicit: temperature/top_p only apply in
    # sampling mode. return_full_text=False drops the echoed prompt (including
    # the 256 raw floats) so only the newly generated report is returned.
    generations = decoder(
        prompt,
        max_new_tokens=180,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        return_full_text=False,
    )
    report = generations[0]["generated_text"].strip()

    # Return both embedding preview + text.
    preview = embedding[:5].tolist()
    return f"✅ Embedding (preview): {preview}\n\n🩺 **AI Radiology Report:**\n{report}"
# ---------------------------
# Gradio UI
# ---------------------------
title = "🩻 MedSigLIP → MedGemma Fusion"
# Build the description from the checkpoint IDs this file actually loads
# (ENCODER_ID / DECODER_ID), so the UI text cannot drift from the models in use.
desc = f"""
Upload an **X-ray image**, and this demo will:
1. Extract its visual embedding using `{ENCODER_ID}`.
2. Generate a **radiology-style report** using `{DECODER_ID}`.

⚠️ Research demo only — not for clinical or diagnostic use.
"""
# Single image in (as a PIL image, matching analyze_xray's input), Markdown out.
demo = gr.Interface(
    fn=analyze_xray,
    inputs=gr.Image(type="pil", label="Upload X-ray Image"),
    outputs=gr.Markdown(label="AI Report"),
    title=title,
    description=desc,
    theme="gradio/soft",
)
def _main():
    """Launch the Gradio server for the `demo` interface built in this file."""
    demo.launch()


if __name__ == "__main__":
    _main()
|