ErzhanAb commited on
Commit
a5bca72
·
verified ·
1 Parent(s): e146310

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -83
app.py CHANGED
@@ -1,112 +1,134 @@
1
- import gradio as gr
2
- import joblib, json, re
3
  from html import unescape
4
 
 
 
 
 
 
5
  # -----------------------------
6
- # 1) Точная копия preprocessor (без изменений)
7
  # -----------------------------
8
- _URL_RE = re.compile(r'https?://\S+|www\.\S+')
9
- _TAG_RE = re.compile(r'[@#]\w+')
10
- _NUM_RE = re.compile(r'\d+')
11
- _WS_RE = re.compile(r'\s+')
 
12
 
13
- def clean_text(s: str) -> str:
14
- """ДОЛЖНА совпадать с версией из обучения, иначе pickle не найдёт функцию."""
 
15
  if not isinstance(s, str):
16
  s = str(s)
17
  s = unescape(s).lower()
18
- s = _URL_RE.sub(' <url> ', s)
19
- s = _TAG_RE.sub(' <tag> ', s)
20
- s = _NUM_RE.sub(' <num> ', s)
21
- s = s.replace('\n', ' ').replace('\t', ' ')
22
- s = _WS_RE.sub(' ', s).strip()
23
- return s
24
-
25
- # ---------------------------------
26
- # 2) Загрузка пайплайна (без конфига)
27
- # ---------------------------------
28
- PIPE = joblib.load("model.joblib")
29
-
30
- # ---------------------------------
31
- # 3) Максимально упрощенный инференс
32
- # ---------------------------------
33
- def predict(comment: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  """
35
- Возвращает словарь {метка: вероятность} для компонента gr.Label.
 
 
36
  """
37
- if not comment or not comment.strip():
38
- return None # Возвращаем None, чтобы очистить поле вывода
39
-
40
- proba_toxic = float(PIPE.predict_proba([comment])[0, 1])
41
- proba_not_toxic = 1 - proba_toxic
42
-
43
- # Возвращаем словарь, gr.Label сам подсветит класс с большей вероятностью
44
- return {"Токсичный": proba_toxic, "Не токсичный": proba_not_toxic}
45
-
46
- # ---------------------------------
47
- # 4) Минималистичный интерфейс
48
- # ---------------------------------
49
-
50
- TITLE = "Анализатор токсичности комментариев"
51
- DESCRIPTION = "Введите комментарий на русском языке. Модель покажет распределение вероятностей между классами «Токсичный» и «Не токсичный»."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  ARTICLE = """
53
- ---
54
- ### Детали модели
55
- * **Архитектура**: TF-IDF (char_wb, n-граммы 4-5) + Логистическая регрессия (L1, class_weight=balanced).
56
- * **Назначение**: Классификация русскоязычных текстов.
57
  """
58
 
59
- # CSS для подключения и применения шрифта "Inter"
60
  CUSTOM_CSS = """
61
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;700&display=swap');
62
- gradio-app {
63
- font-family: 'Inter', sans-serif;
64
- }
65
  """
66
 
67
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=CUSTOM_CSS) as demo:
 
68
  gr.Markdown(f"# {TITLE}")
69
  gr.Markdown(DESCRIPTION)
70
 
71
  with gr.Row():
72
- # Левая колонка для ввода
73
  with gr.Column(scale=2):
74
- comment_input = gr.Textbox(
75
- label="Текст комментария",
76
- lines=5,
77
- placeholder="Напишите что-нибудь...",
78
- )
79
  with gr.Row():
80
- clear_btn = gr.Button("Очистить", variant="secondary")
81
- analyze_btn = gr.Button("Анализ", variant="primary")
82
 
83
- # Правая колонка для вывода
84
  with gr.Column(scale=1):
85
- result_label = gr.Label(label="Результат", num_top_classes=2)
86
-
 
87
  gr.Markdown(ARTICLE)
88
-
89
- # --- Логика взаимодействия компонентов ---
90
-
91
- def clear_all():
92
- return "", None
93
-
94
- # Привязка функций к кнопкам и событиям
95
- analyze_btn.click(
96
- fn=predict,
97
- inputs=comment_input,
98
- outputs=result_label
99
- )
100
- comment_input.submit(
101
- fn=predict,
102
- inputs=comment_input,
103
- outputs=result_label
104
- )
105
- clear_btn.click(
106
- fn=clear_all,
107
- inputs=[],
108
- outputs=[comment_input, result_label]
109
- )
110
 
111
  if __name__ == "__main__":
112
  demo.launch()
 
1
+ import os, json, re
 
2
  from html import unescape
3
 
4
+ import gradio as gr
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+ from nltk.stem.snowball import RussianStemmer
8
+
9
  # -----------------------------
10
+ # 1) Предобработка: как в обучении
11
  # -----------------------------
12
+ _URL_RE = re.compile(r'https?://\S+|www\.\S+')
13
+ _TAG_RE = re.compile(r'[@#]\w+')
14
+ _NUM_RE = re.compile(r'\d+')
15
+ _PUNCT_RE = re.compile(r"[^\w\s]+", flags=re.UNICODE)
16
+ _WS_RE = re.compile(r"\s+")
17
 
18
+ stemmer = RussianStemmer(ignore_stopwords=False)
19
+
20
+ def clean_and_stem(s: str) -> str:
21
  if not isinstance(s, str):
22
  s = str(s)
23
  s = unescape(s).lower()
24
+ s = _URL_RE.sub(" url ", s)
25
+ s = _TAG_RE.sub(" tag ", s)
26
+ s = _NUM_RE.sub(" num ", s)
27
+ s = _PUNCT_RE.sub(" ", s)
28
+ s = _WS_RE.sub(" ", s).strip()
29
+ if not s:
30
+ return s
31
+ out = []
32
+ for t in s.split():
33
+ out.append(t if t in {"url", "tag", "num"} else stemmer.stem(t))
34
+ return " ".join(out)
35
+
36
+ # -----------------------------
37
+ # 2) Загрузка модели/токенайзера
38
+ # -----------------------------
39
+ # Папка с файлами модели; по умолчанию 'best', можно переопределить переменной окружения MODEL_DIR
40
+ MODEL_DIR = os.getenv("MODEL_DIR", "best" if os.path.exists("best/config.json") else ".")
41
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
44
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
45
+ model.to(DEVICE).eval()
46
+
47
+ # Порог по умолчанию берём из inference_config.json (если есть), иначе 0.40
48
+ DEFAULT_THRESHOLD = 0.40
49
+ try:
50
+ with open(os.path.join(MODEL_DIR, "inference_config.json"), "r", encoding="utf-8") as f:
51
+ DEFAULT_THRESHOLD = float(json.load(f).get("threshold_val", DEFAULT_THRESHOLD))
52
+ except Exception:
53
+ pass
54
+
55
+ MAX_LEN = 256
56
+
57
+ @torch.no_grad()
58
+ def predict(text: str, threshold: float):
59
  """
60
+ Возвращает:
61
+ - dict для gr.Label с распределением вероятностей
62
+ - пояснение (Markdown)
63
  """
64
+ if not text or not text.strip():
65
+ return None, "Введите текст выше."
66
+
67
+ text_prep = clean_and_stem(text)
68
+ batch = tokenizer(text_prep, truncation=True, max_length=MAX_LEN,
69
+ padding=True, return_tensors="pt")
70
+ batch = {k: v.to(DEVICE) for k, v in batch.items()}
71
+
72
+ logits = model(**batch).logits
73
+ probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
74
+ p_not, p_tox = float(probs[0]), float(probs[1])
75
+
76
+ label = "Токсичный" if p_tox >= threshold else "Не токсичный"
77
+ md = (
78
+ f"**Порог:** {threshold:.2f} \n"
79
+ f"**Класс:** **{label}** \n"
80
+ f"P(toxic) = {p_tox:.3f} · P(not_toxic) = {p_not:.3f}"
81
+ )
82
+ return {"Токсичный": p_tox, "Не ток��ичный": p_not}, md
83
+
84
+ # -----------------------------
85
+ # 3) UI
86
+ # -----------------------------
87
+ TITLE = "Анализатор токсичности комментариев (ruBERT-tiny2)"
88
+ DESCRIPTION = (
89
+ "Введите комментарий на русском языке. "
90
+ "Модель (cointegrated/rubert-tiny2, дообученная на вашем датасете) "
91
+ "вернёт распределение вероятностей и итоговую метку с учётом порога."
92
+ )
93
  ARTICLE = """
94
+ ### Детали
95
+ - Архитектура: **ruBERT-tiny2** → linear head (2 класса).
96
+ - Предобработка: замена URL/тегов/чисел, удаление пунктуации, **стемминг** (NLTK).
97
+ - Ввод порога позволяет управлять балансом Precision/Recall.
98
  """
99
 
 
100
  CUSTOM_CSS = """
101
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;700&display=swap');
102
+ gradio-app { font-family: 'Inter', sans-serif; }
 
 
103
  """
104
 
105
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
106
+ css=CUSTOM_CSS, flagging_mode="never") as demo:
107
  gr.Markdown(f"# {TITLE}")
108
  gr.Markdown(DESCRIPTION)
109
 
110
  with gr.Row():
 
111
  with gr.Column(scale=2):
112
+ txt = gr.Textbox(label="Текст комментария", lines=6,
113
+ placeholder="Напишите что-нибудь…")
114
+ thr = gr.Slider(0.0, 1.0, value=DEFAULT_THRESHOLD, step=0.01,
115
+ label="Порог классификации")
 
116
  with gr.Row():
117
+ btn_clear = gr.Button("Очистить", variant="secondary")
118
+ btn_pred = gr.Button("Анализ", variant="primary")
119
 
 
120
  with gr.Column(scale=1):
121
+ dist = gr.Label(label="Распределение вероятностей", num_top_classes=2)
122
+ info = gr.Markdown()
123
+
124
  gr.Markdown(ARTICLE)
125
+
126
+ def _clear():
127
+ return "", DEFAULT_THRESHOLD, None, " "
128
+
129
+ btn_pred.click(predict, inputs=[txt, thr], outputs=[dist, info])
130
+ txt.submit(predict, inputs=[txt, thr], outputs=[dist, info])
131
+ btn_clear.click(_clear, outputs=[txt, thr, dist, info])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  if __name__ == "__main__":
134
  demo.launch()