# Hugging Face Spaces page capture (status text, not program code):
# Spaces:
# Runtime error
# Runtime error
| import gradio as gr | |
| import pandas as pd | |
| import os | |
| from pathlib import Path | |
| import shutil | |
| import tempfile | |
| import uuid | |
| import spaces | |
| from typing import Optional | |
| from backend import ConfigManager, ModelManager, InferenceEngine | |
| from backend.utils.metrics import create_accuracy_table, save_dataframe_to_csv | |
class GradioApp:
    """Gradio front-end for InternVL3 prompt engineering."""

    def __init__(self):
        """Wire up the backend components and warm the default model.

        Preloading is best-effort: a failure is reported but not fatal,
        because the model can also be loaded lazily on first use.
        """
        self.config_manager = ConfigManager()
        self.model_manager = ModelManager(self.config_manager)
        self.inference_engine = InferenceEngine(self.model_manager, self.config_manager)
        try:
            self.model_manager.preload_default_model()
            print("β Default model preloaded successfully!")
        except Exception as exc:
            print(f"β οΈ Default model preloading failed: {str(exc)}")
            print("The model will be loaded when first needed.")
| def get_current_model_status(self) -> str: | |
| """Get current model status for display.""" | |
| return self.model_manager.get_current_model_status() | |
| def handle_stop_button(self): | |
| """Handle stop button click.""" | |
| message = self.inference_engine.set_stop_flag() | |
| return message, gr.update(visible=True) | |
| def on_model_change(self, model_selection: str, quantization_type: str) -> str: | |
| """Handle model/quantization dropdown changes.""" | |
| current_status = self.get_current_model_status() | |
| if model_selection and quantization_type: | |
| available_models = self.config_manager.get_available_models() | |
| target_id = available_models.get(model_selection) | |
| current_model_id = None | |
| if self.model_manager.current_model: | |
| current_model_id = self.model_manager.current_model.model_id | |
| if (current_model_id != target_id or | |
| (self.model_manager.current_model and | |
| self.model_manager.current_model.current_quantization != quantization_type)): | |
| return f"π Will load {model_selection} with {quantization_type} when processing starts" | |
| return current_status | |
| def get_model_choices_with_info(self) -> list[str]: | |
| """Get model choices with type information for dropdown.""" | |
| choices = [] | |
| for model_name in self.config_manager.get_available_models().keys(): | |
| model_config = self.config_manager.get_model_config(model_name) | |
| model_type = model_config.get('model_type', 'unknown').upper() | |
| choices.append(f"{model_name} ({model_type})") | |
| return choices | |
| def extract_model_name_from_choice(self, choice: str) -> str: | |
| """Extract the actual model name from the dropdown choice.""" | |
| return choice.split(' (')[0] if ' (' in choice else choice | |
| def update_image_preview(self, evt: gr.SelectData, df, folder_path): | |
| """Update image preview when table row is selected.""" | |
| if df is None or evt.index[0] >= len(df): | |
| return None, "" | |
| try: | |
| # Use the full dataframe with image paths | |
| full_df = getattr(self.inference_engine, 'full_df', None) | |
| if full_df is None or evt.index[0] >= len(full_df): | |
| return None, "" | |
| selected_row = full_df.iloc[evt.index[0]] | |
| image_path = selected_row["Image Path"] | |
| model_output = selected_row["Model Output"] | |
| if not os.path.exists(image_path): | |
| return None, model_output | |
| file_extension = Path(image_path).suffix | |
| temp_filename = f"gradio_preview_{uuid.uuid4().hex}{file_extension}" | |
| temp_path = os.path.join(tempfile.gettempdir(), temp_filename) | |
| shutil.copy2(image_path, temp_path) | |
| return temp_path, model_output | |
| except Exception as e: | |
| print(f"Error loading image preview: {e}") | |
| return None, "" | |
| def download_results_csv(self, results_table_data): | |
| """Download results as CSV file.""" | |
| try: | |
| print(f"Download function called with data type: {type(results_table_data)}") | |
| if results_table_data is None: | |
| print("No data to download") | |
| return None | |
| # Handle different data types from Gradio | |
| if hasattr(results_table_data, 'values'): | |
| # If it's a pandas DataFrame | |
| df = results_table_data | |
| elif isinstance(results_table_data, list): | |
| # If it's a list of lists or list of dicts | |
| if len(results_table_data) == 0: | |
| print("Empty data") | |
| return None | |
| df = pd.DataFrame(results_table_data, columns=["S.No", "Image Name", "Ground Truth", "Binary Output", "Model Output"]) | |
| else: | |
| # Try to convert to DataFrame | |
| df = pd.DataFrame(results_table_data) | |
| print(f"DataFrame shape: {df.shape}") | |
| print(f"DataFrame columns: {df.columns.tolist()}") | |
| # Create temporary file | |
| temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) | |
| df.to_csv(temp_file.name, index=False) | |
| temp_file.close() | |
| print(f"CSV file created: {temp_file.name}") | |
| return temp_file.name | |
| except Exception as e: | |
| print(f"Error in download_results_csv: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| return None | |
| def submit_and_show_metrics(self, df): | |
| """Generate and show metrics for results.""" | |
| if df is None: | |
| return df, df, None, None, None, gr.update(visible=False), gr.update(visible=False), "" | |
| # Only create metrics if all outputs are valid yes/no responses | |
| try: | |
| metrics_df, cm_plot_path, cm_values = create_accuracy_table(df) | |
| return df, df, metrics_df, cm_plot_path, cm_values, gr.update(visible=True), gr.update(visible=True), "π Metrics calculated successfully!" | |
| except Exception as e: | |
| print(f"Could not create metrics: {str(e)}") | |
| return df, df, None, None, None, gr.update(visible=False), gr.update(visible=True), f"β οΈ Could not calculate metrics: {str(e)}" | |
| def process_input_ui(self, folder_path, prompt, quantization_type, model_selection): | |
| """UI wrapper for processing input with progress updates.""" | |
| if not folder_path or not prompt.strip(): | |
| return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), | |
| "Please upload a folder and enter a prompt.", None, None, None, | |
| gr.update(visible=False), gr.update(visible=False), | |
| gr.update(value="β οΈ Please upload a folder and enter a prompt.", visible=True), "", gr.update(visible=False)) | |
| # Extract actual model name from the dropdown choice | |
| actual_model_name = self.extract_model_name_from_choice(model_selection) | |
| # Check if model needs to be downloaded and show progress | |
| available_models = self.config_manager.get_available_models() | |
| model_id = available_models[actual_model_name] | |
| # Show processing message and hide stop status | |
| yield (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), | |
| None, None, None, None, | |
| gr.update(visible=False), gr.update(visible=False), | |
| gr.update(value="π Initializing processing...", visible=True), prompt, gr.update(visible=False)) | |
| # Process the input | |
| error, show_results, show_image, table, error_message, final_message = self.inference_engine.process_folder_input( | |
| folder_path, prompt, quantization_type, actual_model_name, gr.Progress() | |
| ) | |
| # If error is visible, show results section but keep error visible | |
| if error["visible"]: | |
| yield (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), | |
| error, None, None, None, | |
| gr.update(visible=False), gr.update(visible=False), | |
| gr.update(value=final_message, visible=True), prompt, gr.update(visible=False)) | |
| else: | |
| yield (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), | |
| None, show_results, show_image, table, | |
| gr.update(visible=True), gr.update(visible=False), | |
| gr.update(value=final_message, visible=True), prompt, gr.update(visible=False)) | |
| def rerun_ui(self, df, new_prompt, quantization_type, model_selection): | |
| """UI wrapper for rerun with progress updates.""" | |
| if df is None or not new_prompt.strip(): | |
| return (df, None, None, None, | |
| gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), | |
| gr.update(visible=False), gr.update(visible=True), "β οΈ Please provide a valid prompt", "") | |
| # Extract actual model name from the dropdown choice | |
| actual_model_name = self.extract_model_name_from_choice(model_selection) | |
| # Hide all sections and show only processing, clear model output display | |
| yield (df, None, None, None, | |
| gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), | |
| gr.update(visible=False), gr.update(visible=True), "π Initializing reprocessing...", "Select a row from the table to see model output...") | |
| # Process with new prompt | |
| updated_df, accuracy_table_data, cm_plot, cm_values, section4_vis, progress_vis, final_message = self.inference_engine.rerun_with_new_prompt( | |
| df, new_prompt, quantization_type, actual_model_name, gr.Progress() | |
| ) | |
| # Show prompt editing and results sections again, show Generate Metrics button, hide progress, and clear model output display | |
| yield (updated_df, accuracy_table_data, cm_plot, cm_values, | |
| gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), section4_vis, | |
| gr.update(visible=True), gr.update(visible=False), final_message, "Select a row from the table to see updated model output...") | |
    def create_interface(self):
        """Build and return the Gradio Blocks UI.

        Layout:
          - model/quantization dropdowns + model status + stop controls
          - section 1: folder upload, prompt, Proceed button
          - section 2: prompt editing + Rerun (hidden until first run)
          - section 3: results table with per-row image preview (hidden)
          - section 4: metrics + confusion matrix (hidden)
        """
        # CSS from original app.py
        css = """
        .progress {
            margin: 15px 0;
            padding: 20px;
            border-radius: 12px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            border: none;
            color: white;
            font-weight: 600;
            font-size: 16px;
            text-align: center;
            box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
            animation: progressPulse 2s ease-in-out infinite alternate;
        }
        @keyframes progressPulse {
            0% {
                transform: scale(1);
                box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
            }
            100% {
                transform: scale(1.02);
                box-shadow: 0 6px 20px rgba(102, 126, 234, 0.4);
            }
        }
        .processing {
            background: linear-gradient(45deg, #f0f9ff, #e3f2fd);
            border: 2px solid #1976d2;
            border-radius: 10px;
            padding: 20px;
            text-align: center;
            margin: 10px 0;
        }
        .gr-button.processing {
            background-color: #ffa726 !important;
            color: white !important;
            pointer-events: none;
        }
        /* Stop button styling */
        .stop-button {
            background: linear-gradient(135deg, #ff4757 0%, #c44569 100%) !important;
            border: none !important;
            color: white !important;
            font-weight: 700 !important;
            font-size: 16px !important;
            box-shadow: 0 4px 15px rgba(255, 71, 87, 0.4) !important;
            transition: all 0.3s ease !important;
        }
        .stop-button:hover {
            transform: translateY(-2px) !important;
            box-shadow: 0 8px 25px rgba(255, 71, 87, 0.6) !important;
            background: linear-gradient(135deg, #ff3742 0%, #b83754 100%) !important;
        }
        .stop-status {
            color: #ff4757;
            font-weight: 600;
            background: rgba(255, 71, 87, 0.1);
            padding: 10px;
            border-radius: 8px;
            border-left: 4px solid #ff4757;
            margin: 10px 0;
        }
        /* Enhanced button styling */
        .gr-button {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            border: none;
            border-radius: 8px;
            color: white;
            font-weight: 600;
            transition: all 0.3s ease;
        }
        .gr-button:hover {
            transform: translateY(-2px);
            box-shadow: 0 8px 25px rgba(102, 126, 234, 0.4);
        }
        """
        with gr.Blocks(theme="origin", css=css) as demo:
            gr.Markdown("""
            <h1 style='text-align:center; color:#1976d2; font-size:2.5em; font-weight:bold; margin-bottom:40px!important;'>PROMPT_PILOT</h1>
            <p style='text-align:center; color:#666; font-size:1.1em; margin-bottom:30px;'>
            π€ AI-powered analysis with different vision models
            </p>
            <h2 style='text-align:center; color:#666; font-size:1.1em; margin-bottom:30px;'>
            Note: Currently Accuracy only works properly in case of binary output. For other cases kindly download the csv and calculate the accuracy separately.
            </h2>
            """, elem_id="main-title")
            # Model and Quantization selection dropdowns at the top
            model_choices = self.get_model_choices_with_info()
            # NOTE(review): hard-codes the default model's type label as
            # INTERNVL; must stay in sync with get_model_choices_with_info
            # or the default value won't match any dropdown choice — confirm.
            default_choice = f"{self.config_manager.get_default_model()} (INTERNVL)"
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    value=default_choice,
                    label="π€ Model Selection",
                    info="Select model: InternVL (vision+text), Qwen (text-only)",
                    elem_id="model-dropdown"
                )
                quantization_dropdown = gr.Dropdown(
                    choices=["quantized(8bit)", "non-quantized(fp16)"],
                    value="non-quantized(fp16)",
                    label="π§ Model Quantization",
                    info="Select quantization type: quantized (8bit) uses less memory, non-quantized (fp16) for better quality",
                    elem_id="quantization-dropdown"
                )
            # Model status indicator
            with gr.Row():
                model_status = gr.Markdown(
                    value=self.get_current_model_status(),
                    label="Model Status",
                    elem_classes=["model-status"]
                )
            # Stop button row
            with gr.Row():
                stop_btn = gr.Button("π STOP PROCESSING", variant="stop", size="lg", elem_classes=["stop-button"])
                stop_status = gr.Markdown("", elem_classes=["stop-status"], visible=False)
            # Section 1: folder upload, prompt, and Proceed button.
            with gr.Row(visible=True) as section1_row:
                with gr.Column():
                    folder_input = gr.File(
                        label="Upload Folder",
                        file_count="directory",
                        type="filepath"
                    )
                with gr.Column():
                    prompt_input = gr.Textbox(
                        label="Enter your prompt here",
                        placeholder="Type your prompt...",
                        lines=3
                    )
                with gr.Column():
                    submit_btn = gr.Button("Proceed", variant="primary")
            # Progress indicator for section 1
            with gr.Row(visible=True) as section1_progress_row:
                section1_progress_message = gr.Markdown("", elem_classes=["progress"], visible=False)
            # Section 2: Edit Prompt and Rerun Controls (separate section)
            with gr.Row(visible=False) as section2_prompt_row:
                with gr.Column():
                    with gr.Row():
                        prompt_input_section2 = gr.Textbox(
                            label="Edit Prompt",
                            placeholder="Modify your prompt here...",
                            lines=2,
                            scale=4
                        )
                        rerun_btn = gr.Button("π Rerun", variant="secondary", size="lg", scale=1)
            # Section 3: Results Display
            with gr.Row(visible=False) as section3_results_row:
                error_message = gr.Textbox(label="Error Message", visible=False)
                with gr.Column(scale=1):
                    image_preview = gr.Image(label="Selected Image", height=270, width=480)
                    model_output_display = gr.Textbox(
                        label="Model Output for Selected Image",
                        placeholder="Select a row from the table to see model output...",
                        interactive=False,
                        lines=3
                    )
                with gr.Column(scale=2):
                    with gr.Row():
                        gr.HTML("")  # Empty space to push button to right
                        download_results_btn = gr.Button("π₯ CSV", size="sm", scale=1)
                        results_csv_output = gr.File(label="", visible=True, scale=1, show_label=False)
                    results_table = gr.Dataframe(
                        headers=["S.No", "Image Name", "Ground Truth", "Binary Output", "Model Output"],
                        label="Results",
                        interactive=True,  # Make it editable for ground truth input
                        col_count=(5, "fixed")
                    )
            # Generate Metrics button
            with gr.Row(visible=False) as section3_submit_row:
                with gr.Column():
                    submit_results_btn = gr.Button("Generate Metrics", variant="primary", size="lg")
            # Progress indicator row
            with gr.Row(visible=False) as progress_row:
                progress_message = gr.Markdown("", elem_classes=["progress"])
            # Section 4: Metrics and confusion matrix
            with gr.Row(visible=False) as section4_metrics_row:
                with gr.Column(scale=2):
                    confusion_matrix_plot = gr.Image(
                        label="Confusion Matrix"
                    )
                with gr.Column(scale=2):
                    accuracy_table = gr.Dataframe(
                        label="Performance Metrics",
                        interactive=False
                    )
                    confusion_matrix_table = gr.Dataframe(
                        label="Confusion Matrix Table",
                        interactive=False
                    )
            # State to store folder path
            folder_path_state = gr.State()
            folder_input.change(
                fn=lambda x: x,
                inputs=[folder_input],
                outputs=[folder_path_state]
            )
            # Event handlers
            # NOTE(review): results_table appears twice in this outputs list
            # (slots 5 and 7), mirroring the shape of process_input_ui's
            # yielded tuples — confirm the duplication is intended.
            submit_btn.click(
                fn=self.process_input_ui,
                inputs=[folder_input, prompt_input, quantization_dropdown, model_dropdown],
                outputs=[section1_row, section2_prompt_row, section3_results_row, error_message, results_table, image_preview, results_table, section3_submit_row, section4_metrics_row, section1_progress_message, prompt_input_section2, stop_status]
            )
            # Row selection drives the image preview + model output box.
            results_table.select(
                fn=self.update_image_preview,
                inputs=[results_table, folder_path_state],
                outputs=[image_preview, model_output_display]
            )
            # NOTE(review): results_table is also duplicated here (slots 1-2),
            # matching submit_and_show_metrics returning (df, df, ...).
            submit_results_btn.click(
                fn=self.submit_and_show_metrics,
                inputs=[results_table],
                outputs=[results_table, results_table, accuracy_table, confusion_matrix_plot, confusion_matrix_table, section4_metrics_row, progress_row, progress_message]
            )
            download_results_btn.click(
                fn=self.download_results_csv,
                inputs=[results_table],
                outputs=[results_csv_output]
            )
            rerun_btn.click(
                fn=self.rerun_ui,
                inputs=[results_table, prompt_input_section2, quantization_dropdown, model_dropdown],
                outputs=[results_table, accuracy_table, confusion_matrix_plot, confusion_matrix_table,
                         section1_row, section2_prompt_row, section3_results_row, section4_metrics_row, section3_submit_row, progress_row, progress_message, model_output_display]
            )
            # Model change handler to update status
            model_dropdown.change(
                fn=self.on_model_change,
                inputs=[model_dropdown, quantization_dropdown],
                outputs=[model_status]
            )
            quantization_dropdown.change(
                fn=self.on_model_change,
                inputs=[model_dropdown, quantization_dropdown],
                outputs=[model_status]
            )
            # Stop button click handler.
            # NOTE(review): stop_status is listed twice because
            # handle_stop_button returns (message, gr.update(visible=True)) —
            # the second slot re-shows the widget.
            stop_btn.click(
                fn=self.handle_stop_button,
                inputs=[],
                outputs=[stop_status, stop_status]
            )
        return demo
| def launch(self, **kwargs): | |
| """Launch the Gradio application.""" | |
| demo = self.create_interface() | |
| return demo.launch(**kwargs) |