blackshadow1 committed
Commit 363e473 · verified · 1 Parent(s): 721652b

updated the code ✅✅

Files changed (1): app.py (+124 -48)
app.py CHANGED
@@ -1,60 +1,136 @@
 import gradio as gr
-from transformers import T5ForConditionalGeneration, AutoTokenizer
-
-# Load the T5 model and tokenizer
-model = T5ForConditionalGeneration.from_pretrained("t5-small")
-tokenizer = AutoTokenizer.from_pretrained("t5-small")
-
-def generate_response(image, text, specialization):
-    """
-    Generate a response based on the uploaded image and/or user query.
-    """
-    if not text.strip():
-        return "Please enter a valid query."
-
-    # Prepare the input text
-    input_text = f"Specialization: {specialization}\n{text}"
-    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
-
-    # Generate the response
-    outputs = model.generate(**inputs, max_new_tokens=128)
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    return response
 
 # Gradio Interface
-with gr.Blocks() as app:
-    gr.Markdown("# 🩺 AI Doctor Assistant")
-    gr.Markdown(
-        "Upload a medical image (e.g., X-ray, MRI) and/or ask a health-related question. "
-        "Select a specialization for more accurate responses."
-    )
 
     with gr.Row():
-        image_input = gr.Image(type="pil", label="Upload Medical Image (Optional)")
-        specialization = gr.Dropdown(
-            choices=["general", "radiology", "cardiology", "neurology", "pediatrics"],
-            value="general",
-            label="Specialization"
-        )
 
-        text_input = gr.Textbox(
-            label="Enter Your Query",
-            placeholder="Type your health-related question here..."
-        )
 
-        response_output = gr.Textbox(
-            label="AI Doctor's Response",
-            interactive=False
-        )
 
-        submit_button = gr.Button("Get Response")
 
-    submit_button.click(
         fn=generate_response,
-        inputs=[image_input, text_input, specialization],
-        outputs=[response_output]
     )
 
-# Launch the app
-app.launch()
 
 import gradio as gr
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+import torch
+import logging
+from functools import lru_cache
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Load model and processor once, cached for the process lifetime
+# (Gradio has no gr.Cache decorator; functools.lru_cache does the job)
+@lru_cache(maxsize=1)
+def load_model():
+    try:
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            "prithivMLmods/Radiology-Infer-Mini",
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            device_map="auto"
+        )
+        processor = AutoProcessor.from_pretrained("prithivMLmods/Radiology-Infer-Mini")
+        return model, processor
+    except Exception as e:
+        logger.error(f"Error loading model: {e}")
+        raise
+
+model, processor = load_model()
+
+def generate_response(image, text, specialization, history=None):
+    """Generate a response combining image and text inputs, with error handling."""
+    try:
+        # Validate inputs
+        if not image and not text.strip():
+            return "⚠️ Please provide either an image or a text query.", history
+
+        # Prepare messages with specialization context
+        messages = [{
+            "role": "system",
+            "content": f"You are a {specialization} medical assistant. Provide professional, accurate analysis in clear language."
+        }]
+
+        content = []
+        if image:
+            content.append({"type": "image", "image": image})
+        if text.strip():
+            content.append({"type": "text", "text": text})
+
+        messages.append({"role": "user", "content": content})
+
+        # Process inputs
+        inputs = processor(
+            text=processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True),
+            images=[image] if image else None,
+            return_tensors="pt"
+        ).to(model.device)
+
+        # Generate response (do_sample=True so temperature actually takes effect)
+        generated_ids = model.generate(
+            **inputs,
+            max_new_tokens=256,
+            do_sample=True,
+            temperature=0.7,
+            repetition_penalty=1.1,
+            eos_token_id=processor.tokenizer.eos_token_id
+        )
+
+        # Decode only the newly generated tokens
+        response = processor.batch_decode(
+            generated_ids[:, inputs.input_ids.shape[1]:],
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False
+        )[0]
+
+        # Append this exchange to the chat history
+        formatted_history = (history or []) + [(text, response)]
+        return response, formatted_history
+
+    except Exception as e:
+        logger.error(f"Generation error: {e}")
+        return f"❌ Error processing request: {str(e)}", history
 
 # Gradio Interface
+with gr.Blocks(theme=gr.themes.Soft()) as app:
+    gr.Markdown("""
+    # 🩺 AI Medical Assistant
+    **Upload medical images and ask questions** for analysis in various specialties.
+    """)
 
     with gr.Row():
+        with gr.Column(scale=1):
+            specialization = gr.Dropdown(
+                label="Medical Specialty",
+                choices=["General Practice", "Radiology", "Cardiology", "Neurology", "Pediatrics"],
+                value="General Practice"
+            )
+            image_input = gr.Image(type="pil", label="Upload Medical Image", sources=["upload", "clipboard"])
+
+        with gr.Column(scale=2):
+            chatbot = gr.Chatbot(height=400, label="Consultation History")
+            text_input = gr.Textbox(
+                label="Patient Query",
+                placeholder="Describe symptoms or ask about the image...",
+                lines=3
+            )
+            submit_btn = gr.Button("Submit", variant="primary")
 
+    # Examples
+    gr.Examples(
+        examples=[
+            ["Radiology", "Explain this chest X-ray finding.", "https://cdn-uploads.huggingface.co/production/uploads/669d65ab8580d17cb6beabc0/eg2_EACmEhtkFqp3alrbm.png"],
+            ["Cardiology", "Interpret these ECG results.", "https://cdn-uploads.huggingface.co/production/uploads/669d65ab8580d17cb6beabc0/jOmU2DXOnnU0jivCA6xWI.png"]
+        ],
+        inputs=[specialization, text_input, image_input]
+    )
 
+    # Event handling
+    submit_event = text_input.submit(
         fn=generate_response,
+        inputs=[image_input, text_input, specialization, chatbot],
+        outputs=[text_input, chatbot]
+    )
+
+    click_event = submit_btn.click(
+        fn=generate_response,
+        inputs=[image_input, text_input, specialization, chatbot],
+        outputs=[text_input, chatbot]
+    )
+
+    # Clear inputs after submission (enter key and button alike)
+    submit_event.then(
+        lambda: ("", None),
+        outputs=[text_input, image_input]
+    )
+    click_event.then(
+        lambda: ("", None),
+        outputs=[text_input, image_input]
     )
 
+if __name__ == "__main__":
+    app.launch(server_name="0.0.0.0" if torch.cuda.is_available() else "127.0.0.1")
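
A quick local check of the updated handler could call generate_response() directly, bypassing the Gradio UI. This is a minimal sketch, not part of the commit: it assumes the Space's dependencies (torch, transformers, Pillow) are installed, and "chest_xray.png" is a hypothetical placeholder image.

# Smoke test for generate_response(); "chest_xray.png" is a hypothetical
# placeholder file, not shipped with this Space.
from PIL import Image

image = Image.open("chest_xray.png")
reply, history = generate_response(
    image=image,
    text="Explain this chest X-ray finding.",
    specialization="Radiology",
    history=None,
)
print(reply)        # the model's answer
print(history[-1])  # ("Explain this chest X-ray finding.", reply)

Since model and processor are loaded at module import, this runs the same code path the UI uses; the returned history tuple matches the (user, assistant) pair format that gr.Chatbot displays.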