ErzhanAb's picture
Update app.py
e413224 verified
import os
import json
import gradio as gr
import numpy as np
# ==============================
# 1) КОНФИГ (без изменений)
# ==============================
MODEL_DIR = "rubert_tiny2_toxic_minprep"
MAX_LEN = 256
DEFAULT_THRESHOLD = 0.65
cfg_path = os.path.join(MODEL_DIR, "inference_config.json")
try:
if os.path.exists(cfg_path):
with open(cfg_path, "r", encoding="utf-8") as f:
DEFAULT_THRESHOLD = float(json.load(f).get("threshold_val", DEFAULT_THRESHOLD))
except Exception:
pass
# ==============================
# 2) ЗАГРУЗКА МОДЕЛИ (без изменений)
# ==============================
TRANSFORMER = {"model": None, "tokenizer": None, "device": "cpu", "loaded": False}
try:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
model.to(device).eval()
TRANSFORMER.update({
"model": model, "tokenizer": tokenizer, "device": device, "loaded": True
})
print(f"[INFO] Модель загружена из: {MODEL_DIR} | device={device}")
except Exception as e:
print(f"[WARN] Не удалось загрузить модель из '{MODEL_DIR}': {e}")
# ==============================
# 3) ИНФЕРЕНС (без изменений)
# ==============================
def infer(comment: str, threshold: float):
text = (comment or "").strip()
if not text:
return "—", {"Токсичный": 0.0, "Не токсичный": 1.0}
if not TRANSFORMER["loaded"]:
return "⚠️ Модель не загружена", {"Токсичный": 0.0, "Не токсичный": 1.0}
import torch
tok = TRANSFORMER["tokenizer"](text, return_tensors="pt", truncation=True, max_length=MAX_LEN)
tok = {k: v.to(TRANSFORMER["device"]) for k, v in tok.items()}
with torch.inference_mode():
logits = TRANSFORMER["model"](**tok).logits
p_toxic = float(torch.softmax(logits, dim=1)[0, 1].detach().cpu().item())
verdict = "Токсичный" if p_toxic >= threshold else "Не токсичный"
dist = {"Токсичный": p_toxic, "Не токсичный": 1.0 - p_toxic}
return verdict, dist
# ==============================
# 4) ФУНКЦИИ-ОБЁРТКИ ДЛЯ UI (ИЗМЕНЕНО)
# ==============================
def predict_for_ui(comment: str, threshold: float):
"""
Вызывает infer и форматирует результат в красивый HTML-блок.
"""
verdict, dist = infer(comment, threshold)
if verdict == "Токсичный":
p_toxic = dist["Токсичный"]
return f"""
<div class='result-box toxic'>
<p class='verdict-text'>Токсичный</p>
<p class='probability-text'>Вероятность: {p_toxic:.1%}</p>
</div>
"""
elif verdict == "Не токсичный":
p_toxic = dist["Токсичный"]
return f"""
<div class='result-box neutral'>
<p class='verdict-text'>Не токсичный</p>
<p class='probability-text'>Вероятность: {p_toxic:.1%}</p>
</div>
"""
else:
# Для ошибок или пустого состояния
return f"<div class='result-box default'><p class='verdict-text'>{verdict}</p></div>"
def clear_all():
"""Сбрасывает UI к дефолтному состоянию."""
default_html = "<div class='result-box default'><p class='verdict-text'>—</p></div>"
return "", DEFAULT_THRESHOLD, default_html
# ==============================
# 5) UI (НОВЫЙ ДИЗАЙН И ВЫВОД)
# ==============================
TITLE = "Анализатор токсичности комментариев"
DESCRIPTION = "Модель на базе `ruBERT-tiny2` для определения токсичности в русскоязычном тексте."
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
:root { --font: 'Inter', system-ui, -apple-system, Segoe UI, Roboto, sans-serif; }
.result-box { padding: 1.5rem; border-radius: 0.5rem; text-align: center; }
.verdict-text { font-size: 1.75rem; font-weight: 700; margin: 0 0 0.5rem 0; }
.probability-text { font-size: 1rem; margin: 0; opacity: 0.8; }
.result-box.toxic { background-color: #ffebee; color: #c62828; }
.result-box.neutral { background-color: #e8f5e9; color: #2e7d32; }
.result-box.default { background-color: var(--neutral-100); color: var(--neutral-800); }
"""
with gr.Blocks(theme=gr.themes.Glass(), css=CUSTOM_CSS) as demo:
gr.Markdown(f"<div style='text-align: center;'><h1>{TITLE}</h1><p>{DESCRIPTION}</p></div>")
with gr.Row():
with gr.Column():
with gr.Group():
comment_input = gr.Textbox(
label="Текст для анализа", lines=7, placeholder="Напишите что-нибудь..."
)
thr = gr.Slider(
label="Порог классификации",
info="Определяет, при какой вероятности комментарий считается токсичным.",
minimum=0.0, maximum=1.0, value=DEFAULT_THRESHOLD, step=0.01
)
with gr.Row():
clear_btn = gr.Button("Очистить", variant="secondary")
analyze_btn = gr.Button("Анализировать", variant="primary")
gr.Markdown("---")
# Единственный блок вывода - красивый HTML-блок
verdict_output = gr.Markdown(
value="<div class='result-box default'><p class='verdict-text'>—</p></div>",
label="Вердикт (с учетом порога)"
)
with gr.Column(variant="panel"):
gr.Markdown(
"""
### О модели
- **База:** `cointegrated/rubert-tiny2` (облегчённый BERT), дообучен на задаче классификации токсичности.
- **Предобработка:** Минимальная (модель принимает сырой текст).
- **Макс. длина:** 256 токенов.
- **Рекомендованный порог:** `~0.65`. Повышение порога (до 0.70+) делает модель более строгой, снижая количество ложных срабатываний.
"""
)
analyze_btn.click(predict_for_ui, [comment_input, thr], [verdict_output])
comment_input.submit(predict_for_ui, [comment_input, thr], [verdict_output])
clear_btn.click(clear_all, [], [comment_input, thr, verdict_output])
if __name__ == "__main__":
demo.launch()