Spaces:
Sleeping
Sleeping
"""
Gradio app for PR Reviewer Assignment Model.

This application provides an interactive interface for predicting PR reviewers
based on PR title and modified files using a fine-tuned DeBERTa model.

For private models, set the HF_TOKEN environment variable:
    export HF_TOKEN=your_huggingface_token
"""
import json
import os
import re

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# --- Model configuration ---------------------------------------------------
# Hub repository the classifier and tokenizer are loaded from.
MODEL_NAME = "yazoniak/pr-assignee-reviewer-deberta"  # Update with your actual model name
# Upper bound on the tokenized input length; longer inputs are truncated.
MAX_LENGTH = 8192
# Probability a label must exceed to be reported as a predicted reviewer.
DEFAULT_THRESHOLD = 0.5

# Authentication token for private models (None when the variable is unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
def load_model():
    """
    Fetch the classifier and its tokenizer, and switch the model to eval mode.

    When HF_TOKEN is set it is forwarded to both ``from_pretrained`` calls so
    a private repository can be read; otherwise both calls are made without
    a token argument.

    Returns:
        tuple: (model, tokenizer, id2label). ``id2label`` comes from the model
        config; if its first key is a string, all keys are converted to int.
    """
    if HF_TOKEN:
        print(f"Using authentication token for private model: {MODEL_NAME}")
        auth_kwargs = {"token": HF_TOKEN}
    else:
        print(f"No token found, attempting to load public model: {MODEL_NAME}")
        print("If this is a private model, set HF_TOKEN environment variable")
        auth_kwargs = {}

    classifier = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, **auth_kwargs
    )
    text_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, **auth_kwargs)
    classifier.eval()

    # Label ids may be stored as strings in the config; use ints so they can
    # be looked up by output position.
    labels = classifier.config.id2label
    if labels and isinstance(next(iter(labels)), str):
        labels = {int(label_id): name for label_id, name in labels.items()}

    return classifier, text_tokenizer, labels
# Load the model once at import time so every request reuses the same
# model, tokenizer and label mapping.
print("Loading model...")
model, tokenizer, id2label = load_model()
print(f"Model loaded successfully with {len(id2label)} reviewers")
def _warning(message: str) -> tuple[str, str]:
    """Build the (markdown, json) pair returned when a request is rejected."""
    # The second output feeds a gr.JSON component, which expects string values
    # to be valid JSON; an empty string is not valid JSON, "{}" is.
    return message, "{}"


def _split_files(files_input: str) -> list[str]:
    """Split a raw file list on commas, semicolons or newlines, dropping blanks."""
    return [
        name.strip()
        for name in re.split(r"[,;\n]", files_input)
        if name.strip()
    ]


def _format_report(
    pr_title: str,
    files_list: list[str],
    threshold: float,
    predicted_reviewers: list[dict],
    all_scores: dict[str, float],
) -> str:
    """Render the Markdown summary shown in the predictions panel."""
    parts = [
        "## Prediction Results\n\n",
        f"**PR Title:** {pr_title}\n\n",
        f"**Files ({len(files_list)}):** {', '.join(files_list[:5])}",
    ]
    # Only the first five files are listed; the rest are summarised.
    if len(files_list) > 5:
        parts.append(f" ... and {len(files_list) - 5} more")
    parts.append(f"\n\n**Threshold:** {threshold:.2f}\n\n")

    if predicted_reviewers:
        parts.append(f"### Predicted Reviewers ({len(predicted_reviewers)})\n\n")
        for rank, pred in enumerate(predicted_reviewers, 1):
            confidence_bar = "🟩" * int(pred["confidence"] * 10)
            parts.append(
                f"{rank}. **{pred['reviewer']}** - {pred['confidence']:.3f} {confidence_bar}\n"
            )
    else:
        parts.append("### No Reviewers Predicted\n\n")
        parts.append("All confidence scores are below the threshold.\n")

    # Show top 5 scores regardless of threshold
    top_scores = sorted(all_scores.items(), key=lambda item: item[1], reverse=True)[:5]
    parts.append("\n### Top 5 Confidence Scores\n\n")
    for reviewer, score in top_scores:
        confidence_bar = "🟦" * int(score * 10)
        parts.append(f"- **{reviewer}**: {score:.3f} {confidence_bar}\n")

    return "".join(parts)


def predict_reviewers(
    pr_title: str,
    files_input: str,
    threshold: float = DEFAULT_THRESHOLD,
    custom_mapping: str = "",
) -> tuple[str, str]:
    """
    Predict reviewers for a PR based on title and modified files.

    Args:
        pr_title: The PR title/description
        files_input: Modified files separated by commas, semicolons and/or
            newlines (separators may be mixed)
        threshold: Prediction threshold (0-1); a reviewer is predicted when
            its sigmoid score is strictly greater than this value
        custom_mapping: Optional JSON object mapping label IDs to names;
            when given it replaces the model's own id2label mapping

    Returns:
        tuple: (formatted_predictions, all_scores_json). On invalid input the
        first element is a warning message and the second is "{}".
    """
    # Validate inputs
    if not pr_title or not pr_title.strip():
        return _warning("⚠️ Please enter a PR title")
    if not files_input or not files_input.strip():
        return _warning("⚠️ Please enter at least one file")

    files_list = _split_files(files_input)
    if not files_list:
        # The input consisted solely of separators (e.g. "," or ";;").
        return _warning("⚠️ Please enter at least one file")

    if threshold < 0 or threshold > 1:
        return _warning("⚠️ Threshold must be between 0 and 1")

    # Parse custom mapping if provided; default to the model's labels
    label_mapping = id2label
    if custom_mapping and custom_mapping.strip():
        try:
            parsed_mapping = json.loads(custom_mapping)
        except json.JSONDecodeError:
            return _warning("⚠️ Invalid JSON format for custom mapping")
        # Valid JSON that is not an object (list, number, ...) has no items().
        if not isinstance(parsed_mapping, dict):
            return _warning("⚠️ Custom mapping must be a JSON object of ID to name")
        try:
            # Names are coerced to str so they are always usable as dict keys.
            label_mapping = {
                int(label_id): str(name) for label_id, name in parsed_mapping.items()
            }
        except (ValueError, TypeError):
            return _warning("⚠️ Custom mapping must have numeric keys")

    # The title and the file list are encoded as a sentence pair.
    files_text = f"files: {', '.join(files_list)}"
    inputs = tokenizer(
        [pr_title],
        text_pair=[files_text],
        truncation=True,
        max_length=MAX_LENGTH,
        padding=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(**inputs).logits
    # Independent sigmoid per label (multi-label); tolist() yields plain
    # Python floats for the single item in the batch.
    probabilities = torch.sigmoid(logits)[0].tolist()

    predicted_reviewers = []
    all_scores = {}
    for idx, prob in enumerate(probabilities):
        reviewer_name = label_mapping.get(idx, f"label_{idx}")
        all_scores[reviewer_name] = prob
        if prob > threshold:
            predicted_reviewers.append(
                {"reviewer": reviewer_name, "confidence": prob}
            )

    # Highest confidence first
    predicted_reviewers.sort(key=lambda pred: pred["confidence"], reverse=True)

    result_text = _format_report(
        pr_title, files_list, threshold, predicted_reviewers, all_scores
    )

    all_scores_json = json.dumps(
        {
            "predicted_reviewers": predicted_reviewers,
            "all_scores": all_scores,
            "threshold": threshold,
            "num_files": len(files_list),
        },
        indent=2,
    )
    return result_text, all_scores_json
# Sample rows for the examples widget; each row is
# [PR title, file list, threshold, custom mapping JSON].
examples = [
    ["Fix authentication bug in user service", "auth.py, user.py, test_auth.py", 0.5, ""],
    ["Add new payment gateway integration", "gateway.py; payment_routes.py; config.py", 0.5, ""],
]
# Build the Gradio UI: inputs on the left, prediction output on the right,
# followed by clickable examples and static help text.
with gr.Blocks(title="PR Reviewer Assignment", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# PR Reviewer Assignment Model
This model predicts the best team members to review a Pull Request based on:
- **PR Title/Description**: What the PR is about
- **Modified Files**: Which files are being changed
The model uses a fine-tuned **DeBERTa-large** model trained on historical PR patterns.
""")

    with gr.Row():
        # Left column: everything the user fills in.
        with gr.Column(scale=2):
            pr_title_input = gr.Textbox(
                label="PR Title/Description",
                placeholder="e.g., Fix authentication bug in user service",
                lines=2,
            )
            files_input = gr.Textbox(
                label="Modified Files (comma or semicolon separated)",
                placeholder="e.g., auth.py, user.py, test_auth.py",
                lines=3,
            )
            threshold_input = gr.Slider(
                label="Prediction Threshold",
                info="Only show predictions above this confidence score",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=DEFAULT_THRESHOLD,
            )
            with gr.Accordion("Custom Label Mapping (Optional)", open=False):
                gr.Markdown(
                    """
If your deployed model has generic labels (e.g., `label_0`, `label_1`),
you can paste your own ID to name mapping here in JSON format.
**Example format:**
```json
{
"0": "John Doe",
"1": "Jane Smith",
"2": "Bob Johnson"
}
```
"""
                )
                custom_mapping_input = gr.Code(
                    label="Custom ID to Label Mapping (JSON)",
                    language="json",
                    lines=10,
                    value="",
                )
            predict_btn = gr.Button("Predict Reviewers", variant="primary", size="lg")

        # Right column: rendered results plus the raw JSON payload.
        with gr.Column(scale=3):
            prediction_output = gr.Markdown(label="Predictions")
            with gr.Accordion("📋 Detailed JSON Output", open=False):
                json_output = gr.JSON(label="Full Prediction Details")

    # The button and the examples widget share the same wiring.
    prediction_inputs = [
        pr_title_input,
        files_input,
        threshold_input,
        custom_mapping_input,
    ]
    prediction_outputs = [prediction_output, json_output]

    predict_btn.click(
        fn=predict_reviewers,
        inputs=prediction_inputs,
        outputs=prediction_outputs,
    )

    gr.Markdown("### Example Inputs")
    gr.Examples(
        examples=examples,
        fn=predict_reviewers,
        inputs=prediction_inputs,
        outputs=prediction_outputs,
        cache_examples=False,
    )

    gr.Markdown("""
---
### Model Performance
| Metric | Score |
|--------|-------|
| F1 Macro | 0.76 |
| F1 Micro | 0.83 |
| F1 Weighted | 0.82 |
| Subset Accuracy | 0.83 |
### How to Use
1. **Enter PR Title**: Describe what the PR is about
2. **List Modified Files**: Enter file names separated by commas or semicolons
3. **Adjust Threshold** (optional): Lower threshold = more suggestions, Higher threshold = only high-confidence suggestions
4. **Click Predict**: Get reviewer recommendations with confidence scores
### Limitations
- Model is trained on specific team patterns and may not generalize to other teams
- Uses only file names and PR titles, not actual code changes
- New team members may not be predicted accurately without historical data
""")

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()