"""
Gradio app for PR Reviewer Assignment Model.

This application provides an interactive interface for predicting PR reviewers
based on PR title and modified files using a fine-tuned DeBERTa model.

For private models, set the HF_TOKEN environment variable:
    export HF_TOKEN=your_huggingface_token
"""

import json
import os
import re
from typing import Optional

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Model configuration
MODEL_NAME = (
    "yazoniak/pr-assignee-reviewer-deberta"  # Update with your actual model name
)
MAX_LENGTH = 8192
DEFAULT_THRESHOLD = 0.5

# How many highest-scoring reviewers are always listed, regardless of threshold.
TOP_K_SCORES = 5
# How many file names are echoed in the Markdown summary before truncating.
MAX_FILES_SHOWN = 5

# Separators accepted between file names. They may be mixed freely, and
# newlines allow pasting one file per line into the multi-line textbox.
_FILE_SEPARATORS = re.compile(r"[,;\n]")

# Authentication token for private models
HF_TOKEN = os.environ.get("HF_TOKEN", None)


def load_model():
    """
    Load the model and tokenizer from the Hugging Face Hub.

    For private models, requires the HF_TOKEN environment variable to be set.

    Returns:
        tuple: (model, tokenizer, id2label), where id2label maps integer class
        indices to the label names stored in the model config.
    """
    if HF_TOKEN:
        print(f"Using authentication token for private model: {MODEL_NAME}")
    else:
        print(f"No token found, attempting to load public model: {MODEL_NAME}")
        print("If this is a private model, set HF_TOKEN environment variable")

    # An empty HF_TOKEN is treated the same as "no token" (hence `or None`);
    # token=None is the library default, so one call covers both cases.
    token = HF_TOKEN or None
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, token=token)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=token)
    model.eval()

    # Get label mappings from model config. If the keys are strings
    # ("0", "1", ...), convert them so lookups by integer class index work.
    id2label = model.config.id2label
    if id2label and isinstance(next(iter(id2label)), str):
        id2label = {int(k): v for k, v in id2label.items()}

    return model, tokenizer, id2label


# Load model at startup
print("Loading model...")
model, tokenizer, id2label = load_model()
print(f"Model loaded successfully with {len(id2label)} reviewers")


def _parse_files(files_input: str) -> list[str]:
    """Split the files box on commas/semicolons/newlines, dropping blank entries."""
    return [
        name.strip() for name in _FILE_SEPARATORS.split(files_input) if name.strip()
    ]


def _parse_custom_mapping(custom_mapping: str) -> dict[int, str]:
    """
    Parse a JSON object of the form ``{"<label id>": "<reviewer name>"}``.

    Raises:
        ValueError: carrying a user-facing warning message when the text is
            not valid JSON, is not a JSON object, or has non-integer keys.
    """
    try:
        parsed = json.loads(custom_mapping)
    except json.JSONDecodeError:
        raise ValueError("⚠️ Invalid JSON format for custom mapping") from None

    # A JSON list/number/string would otherwise fail later on `.items()`.
    if not isinstance(parsed, dict):
        raise ValueError(
            '⚠️ Custom mapping must be a JSON object like {"0": "John Doe"}'
        )

    try:
        return {int(k): str(v) for k, v in parsed.items()}
    except (ValueError, TypeError):
        raise ValueError("⚠️ Custom mapping must have numeric keys") from None


def _predict_probabilities(pr_title: str, files_list: list[str]):
    """Run the classifier; return one sigmoid probability per label (1-D array)."""
    # The model is fed the PR title paired with a "files: a, b, c" string.
    files_text = f"files: {', '.join(files_list)}"
    inputs = tokenizer(
        [pr_title],
        text_pair=[files_text],
        truncation=True,
        max_length=MAX_LENGTH,
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # Sigmoid (not softmax) scores each label independently, so any number of
    # reviewers can clear the threshold.
    return torch.sigmoid(logits)[0].cpu().numpy()


def _confidence_bar(score: float, symbol: str) -> str:
    """Render a 0-10 block bar for a probability in [0, 1]."""
    return symbol * int(score * 10)


def _format_results(
    pr_title: str,
    files_list: list[str],
    threshold: float,
    predicted_reviewers: list[dict],
    all_scores: dict[str, float],
) -> str:
    """Build the Markdown summary shown in the predictions panel."""
    parts = [
        "## Prediction Results\n\n",
        f"**PR Title:** {pr_title}\n\n",
        f"**Files ({len(files_list)}):** {', '.join(files_list[:MAX_FILES_SHOWN])}",
    ]
    if len(files_list) > MAX_FILES_SHOWN:
        parts.append(f" ... and {len(files_list) - MAX_FILES_SHOWN} more")
    parts.append(f"\n\n**Threshold:** {threshold:.2f}\n\n")

    if predicted_reviewers:
        parts.append(f"### Predicted Reviewers ({len(predicted_reviewers)})\n\n")
        for rank, pred in enumerate(predicted_reviewers, 1):
            bar = _confidence_bar(pred["confidence"], "🟩")
            parts.append(
                f"{rank}. **{pred['reviewer']}** - {pred['confidence']:.3f} {bar}\n"
            )
    else:
        parts.append("### No Reviewers Predicted\n\n")
        parts.append("All confidence scores are below the threshold.\n")

    # Show the top scores regardless of threshold
    top_scores = sorted(all_scores.items(), key=lambda item: item[1], reverse=True)
    parts.append(f"\n### Top {TOP_K_SCORES} Confidence Scores\n\n")
    for reviewer, score in top_scores[:TOP_K_SCORES]:
        parts.append(f"- **{reviewer}**: {score:.3f} {_confidence_bar(score, '🟦')}\n")

    return "".join(parts)


def predict_reviewers(
    pr_title: str,
    files_input: str,
    threshold: float = DEFAULT_THRESHOLD,
    custom_mapping: str = "",
) -> tuple[str, Optional[str]]:
    """
    Predict reviewers for a PR based on its title and modified files.

    Args:
        pr_title: The PR title/description.
        files_input: Modified files separated by commas, semicolons and/or
            newlines (separators may be mixed).
        threshold: Prediction threshold in [0, 1]. A reviewer is predicted when
            its probability is strictly greater than this value.
        custom_mapping: Optional JSON object mapping label IDs to names, e.g.
            ``{"0": "John Doe"}``. When given, it replaces the model's labels;
            IDs missing from it are shown as ``label_<id>``.

    Returns:
        tuple: (formatted_predictions, all_scores_json).
            - On success, the second element is a JSON string with the
              predicted reviewers, every label's score, the threshold and the
              file count.
            - On invalid input, the first element is a warning message and the
              second is ``None``, which clears the ``gr.JSON`` output. An empty
              string is not valid JSON and would make that component error.
    """
    # Validate inputs
    if not pr_title or not pr_title.strip():
        return "⚠️ Please enter a PR title", None
    if not files_input or not files_input.strip():
        return "⚠️ Please enter at least one file", None

    files_list = _parse_files(files_input)
    if not files_list:
        # The box held only separators, e.g. ", ;".
        return "⚠️ Please enter at least one file", None

    if not 0 <= threshold <= 1:
        return "⚠️ Threshold must be between 0 and 1", None

    # Parse custom mapping if provided; default to the model's own labels.
    label_mapping = id2label
    if custom_mapping and custom_mapping.strip():
        try:
            label_mapping = _parse_custom_mapping(custom_mapping)
        except ValueError as err:
            return str(err), None

    probabilities = _predict_probabilities(pr_title, files_list)

    # Collect every label's score and the subset above the threshold.
    all_scores = {}
    predicted_reviewers = []
    for idx, prob in enumerate(probabilities):
        reviewer_name = label_mapping.get(idx, f"label_{idx}")
        score = float(prob)
        all_scores[reviewer_name] = score
        if score > threshold:
            predicted_reviewers.append({"reviewer": reviewer_name, "confidence": score})
    predicted_reviewers.sort(key=lambda pred: pred["confidence"], reverse=True)

    result_text = _format_results(
        pr_title, files_list, threshold, predicted_reviewers, all_scores
    )
    all_scores_json = json.dumps(
        {
            "predicted_reviewers": predicted_reviewers,
            "all_scores": all_scores,
            "threshold": threshold,
            "num_files": len(files_list),
        },
        indent=2,
    )
    return result_text, all_scores_json


# Example inputs
examples = [
    [
        "Fix authentication bug in user service",
        "auth.py, user.py, test_auth.py",
        0.5,
        "",
    ],
    [
        "Add new payment gateway integration",
        "gateway.py; payment_routes.py; config.py",
        0.5,
        "",
    ],
]

# Create Gradio interface
with gr.Blocks(title="PR Reviewer Assignment", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # PR Reviewer Assignment Model

        This model predicts the best team members to review a Pull Request based on:
        - **PR Title/Description**: What the PR is about
        - **Modified Files**: Which files are being changed

        The model uses a fine-tuned **DeBERTa-large** model trained on historical PR patterns.
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            pr_title_input = gr.Textbox(
                label="PR Title/Description",
                placeholder="e.g., Fix authentication bug in user service",
                lines=2,
            )
            files_input = gr.Textbox(
                label="Modified Files (comma, semicolon, or newline separated)",
                placeholder="e.g., auth.py, user.py, test_auth.py",
                lines=3,
            )
            threshold_input = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=DEFAULT_THRESHOLD,
                step=0.05,
                label="Prediction Threshold",
                info="Only show predictions above this confidence score",
            )
            with gr.Accordion("Custom Label Mapping (Optional)", open=False):
                gr.Markdown(
                    """
                    If your deployed model has generic labels (e.g., `label_0`, `label_1`),
                    you can paste your own ID to name mapping here in JSON format.

                    **Example format:**
                    ```json
                    {
                        "0": "John Doe",
                        "1": "Jane Smith",
                        "2": "Bob Johnson"
                    }
                    ```
                    """
                )
                custom_mapping_input = gr.Code(
                    label="Custom ID to Label Mapping (JSON)",
                    language="json",
                    lines=10,
                    value="",
                )
            predict_btn = gr.Button("Predict Reviewers", variant="primary", size="lg")

        with gr.Column(scale=3):
            prediction_output = gr.Markdown(label="Predictions")
            with gr.Accordion("📋 Detailed JSON Output", open=False):
                json_output = gr.JSON(label="Full Prediction Details")

    # Connect the button
    predict_btn.click(
        fn=predict_reviewers,
        inputs=[pr_title_input, files_input, threshold_input, custom_mapping_input],
        outputs=[prediction_output, json_output],
    )

    # Examples section
    gr.Markdown("### Example Inputs")
    gr.Examples(
        examples=examples,
        inputs=[pr_title_input, files_input, threshold_input, custom_mapping_input],
        outputs=[prediction_output, json_output],
        fn=predict_reviewers,
        cache_examples=False,
    )

    gr.Markdown(
        """
        ---
        ### Model Performance

        | Metric | Score |
        |--------|-------|
        | F1 Macro | 0.76 |
        | F1 Micro | 0.83 |
        | F1 Weighted | 0.82 |
        | Subset Accuracy | 0.83 |

        ### How to Use

        1. **Enter PR Title**: Describe what the PR is about
        2. **List Modified Files**: Enter file names separated by commas, semicolons, or new lines
        3. **Adjust Threshold** (optional): Lower threshold = more suggestions, Higher threshold = only high-confidence suggestions
        4. **Click Predict**: Get reviewer recommendations with confidence scores

        ### Limitations

        - Model is trained on specific team patterns and may not generalize to other teams
        - Uses only file names and PR titles, not actual code changes
        - New team members may not be predicted accurately without historical data
        """
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()