File size: 7,288 Bytes
0249c73 1367cdd a5bca72 6fb2a39 0249c73 1367cdd 0249c73 1367cdd 93f3e79 0f016d8 93f3e79 0249c73 90f8bf4 0249c73 93f3e79 6fb2a39 0249c73 1367cdd 0249c73 93f3e79 6fb2a39 93f3e79 6fb2a39 93f3e79 084b355 93f3e79 0249c73 93f3e79 0249c73 6fb2a39 0249c73 1367cdd 0249c73 93f3e79 fb9ddca 6fb2a39 93f3e79 90f8bf4 93f3e79 6fb2a39 783893f 93f3e79 fb9ddca 0f016d8 1367cdd 0f016d8 1367cdd 084b355 e413224 084b355 0f016d8 6918b0b e413224 0249c73 6fb2a39 e413224 6fb2a39 0249c73 e413224 0249c73 0f016d8 1367cdd 6fb2a39 1367cdd e413224 6fb2a39 1367cdd af1fd59 1367cdd e413224 084b355 1367cdd 6918b0b 084b355 73b552b 41796ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import os
import json
import gradio as gr
import numpy as np
# ==============================
# 1) КОНФИГ (без изменений)
# ==============================
MODEL_DIR = "rubert_tiny2_toxic_minprep"
MAX_LEN = 256
DEFAULT_THRESHOLD = 0.65
cfg_path = os.path.join(MODEL_DIR, "inference_config.json")
try:
if os.path.exists(cfg_path):
with open(cfg_path, "r", encoding="utf-8") as f:
DEFAULT_THRESHOLD = float(json.load(f).get("threshold_val", DEFAULT_THRESHOLD))
except Exception:
pass
# ==============================
# 2) ЗАГРУЗКА МОДЕЛИ (без изменений)
# ==============================
TRANSFORMER = {"model": None, "tokenizer": None, "device": "cpu", "loaded": False}
try:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
model.to(device).eval()
TRANSFORMER.update({
"model": model, "tokenizer": tokenizer, "device": device, "loaded": True
})
print(f"[INFO] Модель загружена из: {MODEL_DIR} | device={device}")
except Exception as e:
print(f"[WARN] Не удалось загрузить модель из '{MODEL_DIR}': {e}")
# ==============================
# 3) ИНФЕРЕНС (без изменений)
# ==============================
def infer(comment: str, threshold: float):
text = (comment or "").strip()
if not text:
return "—", {"Токсичный": 0.0, "Не токсичный": 1.0}
if not TRANSFORMER["loaded"]:
return "⚠️ Модель не загружена", {"Токсичный": 0.0, "Не токсичный": 1.0}
import torch
tok = TRANSFORMER["tokenizer"](text, return_tensors="pt", truncation=True, max_length=MAX_LEN)
tok = {k: v.to(TRANSFORMER["device"]) for k, v in tok.items()}
with torch.inference_mode():
logits = TRANSFORMER["model"](**tok).logits
p_toxic = float(torch.softmax(logits, dim=1)[0, 1].detach().cpu().item())
verdict = "Токсичный" if p_toxic >= threshold else "Не токсичный"
dist = {"Токсичный": p_toxic, "Не токсичный": 1.0 - p_toxic}
return verdict, dist
# ==============================
# 4) ФУНКЦИИ-ОБЁРТКИ ДЛЯ UI (ИЗМЕНЕНО)
# ==============================
def predict_for_ui(comment: str, threshold: float):
"""
Вызывает infer и форматирует результат в красивый HTML-блок.
"""
verdict, dist = infer(comment, threshold)
if verdict == "Токсичный":
p_toxic = dist["Токсичный"]
return f"""
<div class='result-box toxic'>
<p class='verdict-text'>Токсичный</p>
<p class='probability-text'>Вероятность: {p_toxic:.1%}</p>
</div>
"""
elif verdict == "Не токсичный":
p_toxic = dist["Токсичный"]
return f"""
<div class='result-box neutral'>
<p class='verdict-text'>Не токсичный</p>
<p class='probability-text'>Вероятность: {p_toxic:.1%}</p>
</div>
"""
else:
# Для ошибок или пустого состояния
return f"<div class='result-box default'><p class='verdict-text'>{verdict}</p></div>"
def clear_all():
"""Сбрасывает UI к дефолтному состоянию."""
default_html = "<div class='result-box default'><p class='verdict-text'>—</p></div>"
return "", DEFAULT_THRESHOLD, default_html
# ==============================
# 5) UI (НОВЫЙ ДИЗАЙН И ВЫВОД)
# ==============================
TITLE = "Анализатор токсичности комментариев"
DESCRIPTION = "Модель на базе `ruBERT-tiny2` для определения токсичности в русскоязычном тексте."
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
:root { --font: 'Inter', system-ui, -apple-system, Segoe UI, Roboto, sans-serif; }
.result-box { padding: 1.5rem; border-radius: 0.5rem; text-align: center; }
.verdict-text { font-size: 1.75rem; font-weight: 700; margin: 0 0 0.5rem 0; }
.probability-text { font-size: 1rem; margin: 0; opacity: 0.8; }
.result-box.toxic { background-color: #ffebee; color: #c62828; }
.result-box.neutral { background-color: #e8f5e9; color: #2e7d32; }
.result-box.default { background-color: var(--neutral-100); color: var(--neutral-800); }
"""
with gr.Blocks(theme=gr.themes.Glass(), css=CUSTOM_CSS) as demo:
gr.Markdown(f"<div style='text-align: center;'><h1>{TITLE}</h1><p>{DESCRIPTION}</p></div>")
with gr.Row():
with gr.Column():
with gr.Group():
comment_input = gr.Textbox(
label="Текст для анализа", lines=7, placeholder="Напишите что-нибудь..."
)
thr = gr.Slider(
label="Порог классификации",
info="Определяет, при какой вероятности комментарий считается токсичным.",
minimum=0.0, maximum=1.0, value=DEFAULT_THRESHOLD, step=0.01
)
with gr.Row():
clear_btn = gr.Button("Очистить", variant="secondary")
analyze_btn = gr.Button("Анализировать", variant="primary")
gr.Markdown("---")
# Единственный блок вывода - красивый HTML-блок
verdict_output = gr.Markdown(
value="<div class='result-box default'><p class='verdict-text'>—</p></div>",
label="Вердикт (с учетом порога)"
)
with gr.Column(variant="panel"):
gr.Markdown(
"""
### О модели
- **База:** `cointegrated/rubert-tiny2` (облегчённый BERT), дообучен на задаче классификации токсичности.
- **Предобработка:** Минимальная (модель принимает сырой текст).
- **Макс. длина:** 256 токенов.
- **Рекомендованный порог:** `~0.65`. Повышение порога (до 0.70+) делает модель более строгой, снижая количество ложных срабатываний.
"""
)
analyze_btn.click(predict_for_ui, [comment_input, thr], [verdict_output])
comment_input.submit(predict_for_ui, [comment_input, thr], [verdict_output])
clear_btn.click(clear_all, [], [comment_input, thr, verdict_output])
if __name__ == "__main__":
demo.launch() |