ErzhanAb commited on
Commit
6fb2a39
·
verified ·
1 Parent(s): 4686fa6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -81
app.py CHANGED
@@ -2,36 +2,57 @@ import os, json, re
2
  from html import unescape
3
 
4
  import gradio as gr
5
- import torch
6
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from nltk.stem.snowball import RussianStemmer
 
8
 
9
- # =========================
10
- # 1) Константы/пути
11
- # =========================
12
- MODEL_DIR = "." # файлы (config.json, model.safetensors, tokenizer.*) лежат в корне
13
- INFER_CFG = os.path.join(MODEL_DIR, "inference_config.json")
14
-
15
- # Порог по умолчанию (если не нашли inference_config.json)
16
- DEFAULT_THRESHOLD = 0.40
17
- if os.path.exists(INFER_CFG):
18
- try:
19
- with open(INFER_CFG, "r", encoding="utf-8") as f:
20
- DEFAULT_THRESHOLD = float(json.load(f).get("threshold_val", DEFAULT_THRESHOLD))
21
- except Exception:
22
- pass
23
-
24
- # =========================
25
- # 2) Предобработка (та же, что при обучении!)
26
- # =========================
27
  _URL_RE = re.compile(r'https?://\S+|www\.\S+')
28
  _TAG_RE = re.compile(r'[@#]\w+')
29
  _NUM_RE = re.compile(r'\d+')
30
  _PUNCT_RE = re.compile(r"[^\w\s]+", flags=re.UNICODE)
31
  _WS_RE = re.compile(r"\s+")
32
 
33
- stemmer = RussianStemmer(ignore_stopwords=False)
34
-
35
  def clean_and_stem(s: str) -> str:
36
  if not isinstance(s, str):
37
  s = str(s)
@@ -48,85 +69,106 @@ def clean_and_stem(s: str) -> str:
48
  out.append(t if t in {"url", "tag", "num"} else stemmer.stem(t))
49
  return " ".join(out)
50
 
51
- # =========================
52
- # 3) Загрузка модели
53
- # =========================
54
- # Читаем локальные файлы — без скачивания с интернета
55
- tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
56
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
57
- model.eval()
58
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
59
- model.to(DEVICE)
60
-
61
- @torch.inference_mode()
62
- def infer_proba(text: str) -> float:
63
- text = clean_and_stem(text)
64
- if not text:
65
  return 0.0
66
- enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
67
- enc = {k: v.to(DEVICE) for k, v in enc.items()}
68
- logits = model(**enc).logits
69
- probs = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
70
- return float(probs[1]) # P(toxic)
71
-
72
- # =========================
73
- # 4) Gradio UI
74
- # =========================
75
- TITLE = "Анализатор токсичности (ruBERT-tiny2)"
76
- DESCRIPTION = (
77
- "Введите комментарий на русском языке. Модель вернёт вероятности классов и метку по выбранному порогу."
78
- )
79
 
80
- CUSTOM_CSS = """
81
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
82
- :root { --font: 'Inter', system-ui, -apple-system, Segoe UI, Roboto, sans-serif; }
83
- """
 
84
 
85
- def predict(comment: str, threshold: float):
 
 
 
 
 
 
 
 
 
 
86
  comment = (comment or "").strip()
87
  if not comment:
88
  return {"Токсичный": 0.0, "Не токсичный": 1.0}, "—"
89
- p_toxic = infer_proba(comment)
 
 
 
 
 
90
  pred = "Токсичный" if p_toxic >= threshold else "Не токсичный"
91
  dist = {"Токсичный": p_toxic, "Не токсичный": 1 - p_toxic}
92
- expl = f"Порог: {threshold:.2f} • Вероятность токсичности: {p_toxic:.3f} → Предсказание: **{pred}**"
 
 
 
 
 
93
  return dist, expl
94
 
95
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
96
- css=CUSTOM_CSS) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  gr.Markdown(f"# {TITLE}")
98
  gr.Markdown(DESCRIPTION)
99
 
100
  with gr.Row():
101
  with gr.Column(scale=2):
102
- inp = gr.Textbox(label="Текст комментария", lines=6, placeholder="Напишите что-нибудь…")
103
- thr = gr.Slider(label="Порог классификации", minimum=0.0, maximum=1.0,
104
- step=0.01, value=DEFAULT_THRESHOLD)
 
 
 
 
105
  with gr.Row():
106
- btn = gr.Button("Анализ", variant="primary")
107
- clr = gr.Button("Очистить", variant="secondary")
108
 
109
  with gr.Column(scale=1):
110
- out_label = gr.Label(label="Распределение по классам", num_top_classes=2)
111
- out_txt = gr.Markdown()
112
-
113
- examples = gr.Examples(
114
- examples=[
115
- ["да ты что, совсем с ума сошёл? это полный бред!", DEFAULT_THRESHOLD],
116
- ["спасибо за помощь, очень полезный совет!", DEFAULT_THRESHOLD],
117
- ],
118
- inputs=[inp, thr],
119
- label="Примеры"
120
- )
121
-
122
- btn.click(predict, [inp, thr], [out_label, out_txt])
123
- inp.submit(predict, [inp, thr], [out_label, out_txt])
124
 
125
- def _clear():
126
- return "", DEFAULT_THRESHOLD, {"Токсичный": 0.0, "Не токсичный": 1.0}, "—"
127
 
128
- clr.click(_clear, [], [inp, thr, out_label, out_txt])
 
 
129
 
130
  if __name__ == "__main__":
131
- # SSR по умолчанию у новых версий Gradio; дополнительных флагов не нужно
132
  demo.launch()
 
2
  from html import unescape
3
 
4
  import gradio as gr
5
+ import numpy as np
6
+
7
+ # ====== TF-IDF + LR (joblib / sklearn) ======
8
+ PIPE = None
9
+ try:
10
+ import joblib
11
+ PIPE = joblib.load("model.joblib") # сохранённый пайплайн TF-IDF+LR
12
+ except Exception as e:
13
+ PIPE = None
14
+ print(f"[WARN] Не удалось загрузить model.joblib: {e}")
15
+
16
+ # ====== Transformer (ruBERT-tiny2) ======
17
+ TRANSFORMER = {"model": None, "tokenizer": None, "device": "cpu"}
18
+ try:
19
+ import torch
20
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
21
+
22
+ MODEL_DIR = "." # в корне лежат config.json, model.safetensors, tokenizer.*
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
26
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
27
+ model.to(device).eval()
28
+
29
+ TRANSFORMER["model"] = model
30
+ TRANSFORMER["tokenizer"] = tokenizer
31
+ TRANSFORMER["device"] = device
32
+ except Exception as e:
33
+ print(f"[WARN] Не удалось загрузить ruBERT: {e}")
34
+
35
+ # ====== Порог по умолчанию ======
36
+ DEFAULT_THRESHOLD = 0.70 # как просили
37
+ # если есть inference_config.json от обучения трансформера — подхватим рекомендованный
38
+ try:
39
+ if os.path.exists("inference_config.json"):
40
+ with open("inference_config.json", "r", encoding="utf-8") as f:
41
+ cfg = json.load(f)
42
+ DEFAULT_THRESHOLD = float(cfg.get("threshold_val", DEFAULT_THRESHOLD))
43
+ except Exception:
44
+ pass
45
+
46
+ # ====== Предобработка для трансформера (как в обучении) ======
47
  from nltk.stem.snowball import RussianStemmer
48
+ stemmer = RussianStemmer(ignore_stopwords=False)
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  _URL_RE = re.compile(r'https?://\S+|www\.\S+')
51
  _TAG_RE = re.compile(r'[@#]\w+')
52
  _NUM_RE = re.compile(r'\d+')
53
  _PUNCT_RE = re.compile(r"[^\w\s]+", flags=re.UNICODE)
54
  _WS_RE = re.compile(r"\s+")
55
 
 
 
56
  def clean_and_stem(s: str) -> str:
57
  if not isinstance(s, str):
58
  s = str(s)
 
69
  out.append(t if t in {"url", "tag", "num"} else stemmer.stem(t))
70
  return " ".join(out)
71
 
72
+ # ====== Инференс ======
73
+ def infer_tfidf(text: str) -> float:
74
+ """Вернёт P(toxic) из TF-IDF+LR. В пайплайне уже есть свой preprocessor."""
75
+ if PIPE is None:
 
 
 
 
 
 
 
 
 
 
76
  return 0.0
77
+ proba = PIPE.predict_proba([text])[0, 1]
78
+ return float(proba)
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ def infer_transformer(text: str) -> float:
81
+ """Вернёт P(toxic) из ruBERT-tiny2 (локальный чекпойнт)."""
82
+ if TRANSFORMER["model"] is None:
83
+ return 0.0
84
+ import torch
85
 
86
+ text = clean_and_stem(text)
87
+ if not text:
88
+ return 0.0
89
+ tok = TRANSFORMER["tokenizer"](text, return_tensors="pt", truncation=True, max_length=256)
90
+ tok = {k: v.to(TRANSFORMER["device"]) for k, v in tok.items()}
91
+ with torch.inference_mode():
92
+ logits = TRANSFORMER["model"](**tok).logits
93
+ p = torch.softmax(logits, dim=1)[0, 1].detach().cpu().item()
94
+ return float(p)
95
+
96
+ def predict(model_name: str, comment: str, threshold: float):
97
  comment = (comment or "").strip()
98
  if not comment:
99
  return {"Токсичный": 0.0, "Не токсичный": 1.0}, "—"
100
+
101
+ if model_name == "ruBERT-tiny2 (трансформер)":
102
+ p_toxic = infer_transformer(comment)
103
+ else: # TF-IDF + Логистическая регрессия
104
+ p_toxic = infer_tfidf(comment)
105
+
106
  pred = "Токсичный" if p_toxic >= threshold else "Не токсичный"
107
  dist = {"Токсичный": p_toxic, "Не токсичный": 1 - p_toxic}
108
+ expl = (
109
+ f"Модель: **{model_name}** \n"
110
+ f"Порог: **{threshold:.2f}** \n"
111
+ f"Вероятность токсичности: **{p_toxic:.3f}** \n"
112
+ f"Предсказание: **{pred}**"
113
+ )
114
  return dist, expl
115
 
116
+ def clear_all():
117
+ return "ruBERT-tiny2 (трансформер)", "", DEFAULT_THRESHOLD, {"Токсичный": 0.0, "Не токсичный": 1.0}, "—"
118
+
119
+ # ====== UI ======
120
+ TITLE = "Анализатор токсичности (две модели)"
121
+ DESCRIPTION = "Выберите модель, задайте порог (по умолчанию 0.70) и введите комментарий."
122
+
123
+ CUSTOM_CSS = """
124
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
125
+ :root { --font: 'Inter', system-ui, -apple-system, Segoe UI, Roboto, sans-serif; }
126
+ """
127
+
128
+ ABOUT_MD = """
129
+ ### Параметры и описание моделей
130
+
131
+ **1) ruBERT-tiny2 (трансформер)**
132
+ - База: `cointegrated/rubert-tiny2` (BERT-tiny для русского).
133
+ - Токенизация: BERT WordPiece.
134
+ - Предобработка: удаление пунктуации, нормализация спец-токенов (`url`, `tag`, `num`), стемминг Snowball.
135
+ - Обучение: 10 эпох с early stopping (по macro-F1), class weights (balanced).
136
+ - Рекомендованный порог по валидации: ~**0.70**.
137
+
138
+ **2) TF-IDF + Логистическая регрессия**
139
+ - Векторизация: `TfidfVectorizer(analyzer="char_wb", ngram_range=(4,5), max_features=200k, min_df≈1.75e-4, max_df≈0.96)`.
140
+ - Классификатор: `LogisticRegression(penalty="l1", solver="liblinear", C≈5.52, class_weight="balanced", max_iter=5000, tol≈2.4e-4)`.
141
+ - Рекомендованный порог (по ранее полученным метрикам): ~**0.40**.
142
+
143
+ **Порог** можно свободно менять слайдером — выберите баланс precision/recall под задачу.
144
+ """
145
+
146
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=CUSTOM_CSS) as demo:
147
  gr.Markdown(f"# {TITLE}")
148
  gr.Markdown(DESCRIPTION)
149
 
150
  with gr.Row():
151
  with gr.Column(scale=2):
152
+ model_sel = gr.Dropdown(
153
+ ["ruBERT-tiny2 (трансформер)", "TF-IDF + Логистическая регрессия"],
154
+ value="ruBERT-tiny2 (трансформер)",
155
+ label="Модель"
156
+ )
157
+ comment_input = gr.Textbox(label="Текст комментария", lines=6, placeholder="Напишите что-нибудь…")
158
+ thr = gr.Slider(label="Порог классификации", minimum=0.0, maximum=1.0, value=DEFAULT_THRESHOLD, step=0.01)
159
  with gr.Row():
160
+ analyze_btn = gr.Button("Анализ", variant="primary")
161
+ clear_btn = gr.Button("Очистить", variant="secondary")
162
 
163
  with gr.Column(scale=1):
164
+ result_label = gr.Label(label="Распределение по классам", num_top_classes=2)
165
+ result_md = gr.Markdown()
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ gr.Markdown(ABOUT_MD)
 
168
 
169
+ analyze_btn.click(predict, [model_sel, comment_input, thr], [result_label, result_md])
170
+ comment_input.submit(predict, [model_sel, comment_input, thr], [result_label, result_md])
171
+ clear_btn.click(clear_all, [], [model_sel, comment_input, thr, result_label, result_md])
172
 
173
  if __name__ == "__main__":
 
174
  demo.launch()