import json
import os
import shutil
from pathlib import Path
from typing import Dict, List, Optional

import evaluate
import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)


# ============================================================================
# 2. CATEGORY CLASSIFICATION MODEL
# ============================================================================

class CategoryClassificationModel:
    """Sequence classifier that assigns a category to a listing.

    The text fed to the model is the listing title and description joined by
    a single space. Use :meth:`train` to fine-tune and save a model,
    :meth:`load` to restore a saved one, and :meth:`predict` for inference.
    """

    def __init__(self, model_path: Optional[str] = None):
        """Create an empty wrapper.

        Args:
            model_path: Stored as-is on the instance. Nothing is loaded here;
                call :meth:`load` (or :meth:`train`) before :meth:`predict`.
        """
        self.model_path = model_path
        self.model = None
        self.tokenizer = None
        self.categories = []

    @staticmethod
    def _build_text(df: pd.DataFrame) -> pd.Series:
        """Return 'title description' per row, with missing values as ''."""
        title = df['title'].fillna('').astype(str)
        description = df['description'].fillna('').astype(str)
        return (title + " " + description).str.strip()

    @staticmethod
    def _encode_labels(df: pd.DataFrame, category_to_id: Dict[str, int],
                       split_name: str) -> List[int]:
        """Map the 'category' column to integer ids, rejecting unknown labels."""
        encoded = df['category'].map(category_to_id)
        unknown_mask = encoded.isna()
        if unknown_mask.any():
            # Fail early with the offending values instead of letting the
            # integer cast below choke on NaN.
            unknown = sorted({str(v) for v in df.loc[unknown_mask, 'category'].unique()})
            raise ValueError(
                f"Catégories inconnues dans le jeu '{split_name}': {unknown}"
            )
        return encoded.astype(int).tolist()

    @staticmethod
    def _clean_path(path: Path) -> bool:
        """Delete a file or directory tree if present; return False on failure."""
        if not os.path.exists(path):
            return True
        try:
            if os.path.isfile(path):
                os.remove(path)
                print(f"✓ Fichier supprimé: {path}")
            elif os.path.isdir(path):
                shutil.rmtree(path)
                print(f"✓ Dossier supprimé: {path}")
        except OSError as e:
            # Best effort: stale logs must not prevent training from starting.
            print(f"⚠ Erreur lors du nettoyage de {path}: {e}")
            return False
        return True

    def train(self, train_df: pd.DataFrame, test_df: pd.DataFrame,
              categories: List[str],
              output_dir: str = "./organized_data/category_classification/model_classification_moderation"):
        """Fine-tune the classifier and save it to ``output_dir``.

        Args:
            train_df: DataFrame with columns ['title', 'description', 'category'].
            test_df: Evaluation DataFrame with the same columns.
            categories: All possible category labels; the position in this
                list is the integer id used by the model.
            output_dir: Directory receiving checkpoints, the final model, the
                tokenizer, ``categories.json`` and a ``logs`` subdirectory.

        Returns:
            The metrics dict returned by ``trainer.evaluate()`` on ``test_df``.

        Raises:
            ValueError: If a DataFrame contains a category that is missing
                from ``categories``.
        """
        print("=" * 70)
        print("ENTRAÎNEMENT DU MODÈLE DE CLASSIFICATION PAR CATÉGORIE")
        print("=" * 70)

        self.categories = categories
        num_labels = len(categories)

        # Map categories to integer ids and back.
        category_to_id = {cat: i for i, cat in enumerate(categories)}
        id_to_category = {i: cat for i, cat in enumerate(categories)}

        # Data preparation (the input frames are not modified).
        train_dataset = Dataset.from_dict({
            'text': self._build_text(train_df).tolist(),
            'category': self._encode_labels(train_df, category_to_id, 'train'),
        })
        test_dataset = Dataset.from_dict({
            'text': self._build_text(test_df).tolist(),
            'category': self._encode_labels(test_df, category_to_id, 'test'),
        })

        # Model
        model_name = "distilbert-base-multilingual-cased"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            id2label=id_to_category,
            label2id=category_to_id,
        )

        # Tokenization
        def tokenize_function(examples):
            result = self.tokenizer(
                examples['text'],
                padding='max_length',
                truncation=True,
                max_length=128,
            )
            result['labels'] = examples['category']
            return result

        train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=['text'])
        test_dataset = test_dataset.map(tokenize_function, batched=True, remove_columns=['text'])

        tensor_columns = ['input_ids', 'attention_mask', 'labels']
        train_dataset.set_format(type='torch', columns=tensor_columns)
        test_dataset.set_format(type='torch', columns=tensor_columns)

        # Training. Logs live under this model's own output directory so that
        # cleaning them never touches another model's files.
        logs_dir = Path(output_dir) / "logs"
        self._clean_path(logs_dir)
        logs_dir.mkdir(parents=True, exist_ok=True)
        print(f"✓ Répertoire logs créé: {logs_dir}")

        training_args = TrainingArguments(
            output_dir=output_dir,
            eval_strategy="epoch",
            save_strategy="epoch",
            logging_dir=str(logs_dir),
            logging_steps=10,
            learning_rate=2e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            num_train_epochs=10,
            weight_decay=0.01,
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            report_to="none",
        )

        accuracy_metric = evaluate.load("accuracy")
        f1_metric = evaluate.load("f1")

        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            predictions = logits.argmax(axis=-1)
            return {
                "accuracy": accuracy_metric.compute(
                    predictions=predictions, references=labels
                )["accuracy"],
                "f1": f1_metric.compute(
                    predictions=predictions, references=labels, average='weighted'
                )["f1"],
            }

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            compute_metrics=compute_metrics,
        )

        trainer.train()
        trainer.save_model(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        self.model_path = output_dir

        # Persist the category list next to the model so load() can restore it.
        with open(Path(output_dir) / "categories.json", "w", encoding="utf-8") as f:
            json.dump(categories, f)

        print(f"\n✓ Modèle de classification sauvegardé dans: {output_dir}")

        return trainer.evaluate()

    def load(self, model_path: str):
        """Load a previously saved model, tokenizer and category list.

        Args:
            model_path: Directory written by :meth:`train`; it must contain
                the model and tokenizer files plus ``categories.json``.
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Load the categories saved alongside the model.
        with open(Path(model_path) / "categories.json", "r", encoding="utf-8") as f:
            self.categories = json.load(f)

        self.model_path = model_path
        print(f"✓ Modèle de classification chargé depuis: {model_path}")

    def predict(self, title: str, description: str, top_k: int = 3) -> Dict:
        """Predict the category of a listing.

        Args:
            title: Listing title; ``None`` is treated as an empty string.
            description: Listing description; ``None`` is treated as empty.
            top_k: Maximum number of ranked predictions to return (>= 1).

        Returns:
            {
                'category': str,
                'confidence': float,
                'top_predictions': List[Tuple[str, float]]
            }

        Raises:
            ValueError: If no model is loaded, no categories are known, or
                ``top_k`` is smaller than 1.
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("Modèle non chargé.")
        if not self.categories:
            raise ValueError("Aucune catégorie chargée.")
        if top_k < 1:
            raise ValueError("top_k doit être supérieur ou égal à 1.")

        # Mirror the training-time text construction, where missing fields
        # are replaced by '' rather than the literal string "None".
        title = '' if title is None else title
        description = '' if description is None else description
        text = f"{title} {description}".strip()

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        )
        # The model may sit on an accelerator (e.g. right after train());
        # the tokenizer always produces CPU tensors.
        device = self.model.device
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        self.model.eval()  # make sure dropout is disabled for inference
        with torch.no_grad():
            logits = self.model(**inputs).logits
            probabilities = torch.softmax(logits, dim=-1)[0]

        # Top-K predictions, never asking for more classes than exist.
        k = min(top_k, len(self.categories), probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k=k)

        top_predictions = [
            (self.categories[idx], prob)
            for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
        ]

        return {
            'category': top_predictions[0][0],
            'confidence': top_predictions[0][1],
            'top_predictions': top_predictions,
        }