File size: 2,441 Bytes
1f9b706
10afc48
4cfd29e
 
1f9b706
4cfd29e
10afc48
4cfd29e
bf46f16
10afc48
 
94ddf63
4cfd29e
10afc48
 
bf46f16
10afc48
 
 
 
4cfd29e
 
 
10afc48
 
 
4cfd29e
10afc48
4cfd29e
10afc48
4cfd29e
10afc48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94ddf63
4cfd29e
10afc48
4cfd29e
10afc48
4cfd29e
10afc48
 
 
4cfd29e
1f9b706
 
4cfd29e
 
10afc48
4cfd29e
 
10afc48
1f9b706
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
from transformers import AutoProcessor, AutoModel, AutoTokenizer, pipeline
from PIL import Image
import torch

# ---------------------------
# Load encoder (MedSigLIP)
# ---------------------------
# Image encoder (MedSigLIP): the processor prepares model inputs, the model
# produces the embeddings consumed further down in this file.
ENCODER_ID = "fokan/medsiglip-448-int8"
encoder_processor = AutoProcessor.from_pretrained(ENCODER_ID)
encoder_model = AutoModel.from_pretrained(ENCODER_ID)
encoder_model.eval()  # inference only; eval() returns the same module

# ---------------------------
# Load decoder (MedGemma)
# ---------------------------
DECODER_ID = "fokan/medgemma-4b-it-int8"
# Half precision is requested only when a CUDA device is visible; otherwise
# the pipeline is asked for full float32.
_decoder_dtype = torch.float32
if torch.cuda.is_available():
    _decoder_dtype = torch.float16
decoder = pipeline(
    "text-generation",
    model=DECODER_ID,
    torch_dtype=_decoder_dtype,
)

# ---------------------------
# Core Function
# ---------------------------
@torch.no_grad()
def analyze_xray(image):
    """Encode an uploaded X-ray and generate a short radiology-style report.

    The image is passed through the module-level ``encoder_processor`` /
    ``encoder_model`` pair to obtain a single embedding vector, which is
    L2-normalised. The first 256 components are serialised into a text prompt
    that is handed to the module-level ``decoder`` text-generation pipeline.

    Args:
        image: The uploaded picture as a PIL image, or ``None`` when the user
            submitted the form without uploading anything.

    Returns:
        A Markdown string. On success it contains a five-value preview of the
        embedding followed by the generated report; when ``image`` is ``None``
        it contains a short message asking the user to upload an image.
    """
    # Gradio passes None when nothing was uploaded; answer with a readable
    # message instead of letting the processor fail on a missing image.
    if image is None:
        return "⚠️ Please upload an X-ray image first."

    # X-rays are frequently stored as grayscale or RGBA files. Convert
    # explicitly so the encoder always receives a 3-channel image
    # (NOTE(review): the processor may already do this depending on its
    # config; the explicit conversion is harmless in that case).
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Step 1: Encode the image into embedding
    inputs = encoder_processor(
        images=image,
        text=["chest x-ray"],
        return_tensors="pt",
        padding=True,
    )
    outputs = encoder_model(**inputs)

    # Prefer the dedicated image embedding; fall back to mean-pooling the
    # token dimension of whatever hidden states the model returned. The
    # `is not None` checks matter because an output object can expose an
    # attribute that is present but unset.
    image_embeds = getattr(outputs, "image_embeds", None)
    last_hidden = getattr(outputs, "last_hidden_state", None)
    if image_embeds is not None:
        embedding = image_embeds[0]
    elif last_hidden is not None:
        embedding = last_hidden.mean(dim=1)[0]
    else:
        embedding = list(outputs.values())[0].mean(dim=1)[0]

    # L2-normalise; the clamp guards against dividing by a zero norm.
    embedding = embedding / embedding.norm().clamp_min(1e-12)

    # Step 2: Generate a short diagnostic report
    prompt = (
        "You are a radiologist. Analyze this chest X-ray embedding vector and describe any possible findings, "
        "anomalies, or impressions as a short professional report.\n"
        f"<embedding>{embedding[:256].tolist()}</embedding>"
    )

    generations = decoder(
        prompt,
        max_new_tokens=180,
        do_sample=True,  # temperature/top_p only take effect when sampling
        temperature=0.8,
        top_p=0.9,
        # Without this the pipeline returns prompt + completion, so the
        # "report" would start with the instruction and 256 raw floats.
        return_full_text=False,
    )
    report = generations[0]["generated_text"].strip()

    # Return both embedding preview + text
    preview = embedding[:5].tolist()
    return f"✅ Embedding (preview): {preview}\n\n🩺 **AI Radiology Report:**\n{report}"

# ---------------------------
# Gradio UI
# ---------------------------
title = "🩻 MedSigLIP → MedGemma Fusion"
# The description is built from ENCODER_ID / DECODER_ID so the model names
# shown to the user always match the checkpoints that are actually loaded
# (the previous hard-coded text named different "-fp16-pruned20" repos).
desc = f"""
Upload an **X-ray image**, and this demo will:
1. Extract its visual embedding using `{ENCODER_ID}`.
2. Generate a **radiology-style report** using `{DECODER_ID}`.
"""

# Single-input / single-output interface: a PIL image goes in, the Markdown
# string returned by analyze_xray comes out.
demo = gr.Interface(
    fn=analyze_xray,
    inputs=gr.Image(type="pil", label="Upload X-ray Image"),
    outputs=gr.Markdown(label="AI Report"),
    title=title,
    description=desc,
    theme="gradio/soft",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()