yazoniak committed
Commit 2166d44 · verified · 1 Parent(s): 82d8d08

Update app.py

Files changed (1)
  1. app.py +21 -2
app.py CHANGED
@@ -3,12 +3,16 @@ Gradio app for PR Reviewer Assignment Model.
 
 This application provides an interactive interface for predicting PR reviewers
 based on PR title and modified files using a fine-tuned DeBERTa model.
+
+For private models, set the HF_TOKEN environment variable:
+    export HF_TOKEN=your_huggingface_token
 """
 
 import gradio as gr
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 import torch
 import json
+import os
 
 
 # Model configuration
@@ -18,16 +22,31 @@ MODEL_NAME = (
 MAX_LENGTH = 8192
 DEFAULT_THRESHOLD = 0.5
 
+# Authentication token for private models
+HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
 
 def load_model():
     """
     Load the model and tokenizer.
 
+    For private models, requires HF_TOKEN environment variable to be set.
+
     Returns:
         tuple: (model, tokenizer, id2label)
     """
-    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    if HF_TOKEN:
+        print(f"Using authentication token for private model: {MODEL_NAME}")
+        model = AutoModelForSequenceClassification.from_pretrained(
+            MODEL_NAME, token=HF_TOKEN
+        )
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
+    else:
+        print(f"No token found, attempting to load public model: {MODEL_NAME}")
+        print("If this is a private model, set HF_TOKEN environment variable")
+        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
     model.eval()
 
     # Get label mappings from model config
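
For reference, a minimal sketch of how the patched loading path fits together and how it might be exercised outside Gradio. The repo id, the input format, and the sigmoid/threshold step are assumptions (MODEL_NAME is truncated in the hunk header, and the rest of load_model is not shown in this diff); only the HF_TOKEN handling mirrors the change above.

import os

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder repo id: the real MODEL_NAME is truncated in the hunk header ("MODEL_NAME = (").
MODEL_NAME = "your-org/pr-reviewer-model"
MAX_LENGTH = 8192
DEFAULT_THRESHOLD = 0.5

# Authentication token for private models, read from the environment as in the patch.
HF_TOKEN = os.environ.get("HF_TOKEN", None)


def load_model():
    """Load the model and tokenizer, passing HF_TOKEN only when one is set."""
    if HF_TOKEN:
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    model.eval()
    # Assumed continuation of the truncated hunk: label mapping taken from the model config.
    id2label = model.config.id2label
    return model, tokenizer, id2label


if __name__ == "__main__":
    model, tokenizer, id2label = load_model()
    # Hypothetical input format: PR title plus modified file paths in one text field.
    text = "Fix flaky login test\nsrc/auth/login.py\ntests/test_login.py"
    inputs = tokenizer(text, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
    with torch.no_grad():
        probs = torch.sigmoid(model(**inputs).logits)[0]  # assumed multi-label scores, one per reviewer
    suggested = [id2label[i] for i, p in enumerate(probs) if p >= DEFAULT_THRESHOLD]
    print(suggested)

Reading the token from the environment rather than hard-coding it keeps the app usable with public checkpoints while letting a private deployment supply credentials through an environment variable (for example a Space secret).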