"""Gradio demo: embed an X-ray with MedSigLIP and draft a report with MedGemma.

The image is encoded into an embedding vector, a slice of that vector is
pasted as text into a prompt, and a text-generation pipeline produces a short
radiology-style report from that prompt.
"""

from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer, pipeline

# ---------------------------
# Load encoder (MedSigLIP)
# ---------------------------
ENCODER_ID = "fokan/medsiglip-448-int8"
encoder_processor = AutoProcessor.from_pretrained(ENCODER_ID)
encoder_model = AutoModel.from_pretrained(ENCODER_ID).eval()

# ---------------------------
# Load decoder (MedGemma)
# ---------------------------
DECODER_ID = "fokan/medgemma-4b-it-int8"
# NOTE(review): float16 is selected whenever CUDA is available, but no device
# is passed here — confirm the pipeline actually places the model on the GPU
# in the transformers version this demo is deployed with.
decoder = pipeline(
    "text-generation",
    model=DECODER_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

# Number of leading embedding values pasted into the prompt / shown in the UI.
_PROMPT_EMBEDDING_DIMS = 256
_PREVIEW_DIMS = 5


# ---------------------------
# Core Function
# ---------------------------
def _embed_image(image: Image.Image) -> torch.Tensor:
    """Encode *image* with the encoder and return its L2-normalised embedding."""
    # X-rays are frequently stored as grayscale; hand the processor RGB input
    # so channel handling does not depend on the source file's mode.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # A text prompt is passed alongside the image, presumably because the
    # joint model's forward pass expects text inputs too — TODO confirm.
    inputs = encoder_processor(
        images=image,
        text=["chest x-ray"],
        return_tensors="pt",
        padding=True,
    )
    outputs = encoder_model(**inputs)

    # Prefer the dedicated image embedding; otherwise mean-pool whatever
    # hidden states the model exposes.
    if hasattr(outputs, "image_embeds"):
        embedding = outputs.image_embeds[0]
    elif hasattr(outputs, "last_hidden_state"):
        embedding = outputs.last_hidden_state.mean(dim=1)[0]
    else:
        embedding = next(iter(outputs.values())).mean(dim=1)[0]

    # Clamp the norm so an all-zero embedding cannot turn into NaNs.
    return embedding / embedding.norm().clamp_min(1e-12)


def _build_prompt(embedding: torch.Tensor) -> str:
    """Build the decoder prompt embedding the leading values of *embedding*."""
    return (
        "You are a radiologist. Analyze this chest X-ray embedding vector and describe any possible findings, "
        "anomalies, or impressions as a short professional report.\n"
        f"{embedding[:_PROMPT_EMBEDDING_DIMS].tolist()}"
    )


@torch.no_grad()
def analyze_xray(image: Optional[Image.Image]) -> str:
    """Produce a Markdown string with an embedding preview and a generated report.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        Markdown text containing the first few embedding values followed by
        the decoder's generated report, or a short notice when *image* is
        ``None``.
    """
    # Gradio passes None when the user submits without choosing an image.
    if image is None:
        return "⚠️ Please upload an X-ray image first."

    # Step 1: Encode the image into embedding
    embedding = _embed_image(image)

    # Step 2: Generate a short diagnostic report
    prompt = _build_prompt(embedding)
    generations = decoder(
        prompt,
        max_new_tokens=180,
        do_sample=True,  # temperature/top_p only apply when sampling is on
        temperature=0.8,
        top_p=0.9,
        return_full_text=False,  # keep the prompt (and its 256 floats) out of the report
    )
    report = generations[0]["generated_text"].strip()

    # Return both embedding preview + text
    preview = embedding[:_PREVIEW_DIMS].tolist()
    return f"✅ Embedding (preview): {preview}\n\n🩺 **AI Radiology Report:**\n{report}"


# ---------------------------
# Gradio UI
# ---------------------------
title = "🩻 MedSigLIP → MedGemma Fusion"
# Built from the model IDs so the description always names the checkpoints
# that are actually loaded above.
desc = f"""
Upload an **X-ray image**, and this demo will:
1. Extract its visual embedding using `{ENCODER_ID}`.
2. Generate a **radiology-style report** using `{DECODER_ID}`.
"""

demo = gr.Interface(
    fn=analyze_xray,
    inputs=gr.Image(type="pil", label="Upload X-ray Image"),
    outputs=gr.Markdown(label="AI Report"),
    title=title,
    description=desc,
    theme="gradio/soft",
)

if __name__ == "__main__":
    demo.launch()