import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datetime import datetime
import gradio as gr
from typing import Dict, List, Union, Optional
import logging
import re

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ContentAnalyzer:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.categories = [
            "Violence", "Death", "Substance Use", "Gore",
            "Vomit", "Sexual Content", "Sexual Abuse",
            "Self-Harm", "Gun Use", "Animal Cruelty",
            "Mental Health Issues"
        ]
        self.pattern = re.compile(r'\b(' + '|'.join(self.categories) + r')\b', re.IGNORECASE)
        logger.info(f"Initialized analyzer with device: {self.device}")
        self._load_model()
    def _load_model(self) -> None:
        """Load model and tokenizer with CPU optimization"""
        try:
            logger.info("Loading model components...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
                use_fast=True,
                truncation_side="left"
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True
            ).to(self.device).eval()
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Model loading failed: {str(e)}")
            raise
    def _chunk_text(self, text: str, max_tokens: int = 512) -> List[str]:
        """Context-aware chunking with token counting"""
        paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
        chunks = []
        current_chunk = []
        current_length = 0
        for para in paragraphs:
            para_tokens = self.tokenizer.encode(para, add_special_tokens=False)
            para_length = len(para_tokens)
            if current_length + para_length > max_tokens and current_chunk:
                chunk_text = "\n\n".join(current_chunk)
                chunks.append(chunk_text)
                current_chunk = [para]
                current_length = para_length
            else:
                current_chunk.append(para)
                current_length += para_length
        if current_chunk:
            chunk_text = "\n\n".join(current_chunk)
            chunks.append(chunk_text)
        logger.info(f"Split text into {len(chunks)} chunks (max_tokens={max_tokens})")
        return chunks
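
    # Illustrative chunking behavior (hypothetical token counts): with
    # max_tokens=512, paragraphs of 200, 250, and 300 tokens become two chunks,
    # ["p1\n\np2", "p3"], because adding the third paragraph would push the
    # running total past the limit. A single paragraph longer than max_tokens
    # is kept whole rather than split further.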

    async def _analyze_chunk(self, chunk: str) -> tuple[List[str], str]:
        """Deep analysis with step-by-step reasoning"""
        prompt = f"""As a deep-thinking content analyzer, carefully evaluate this text for sensitive content.
Input text: {chunk}
Think through each step:
1. What is happening in the text?
2. What potentially sensitive themes or elements are present?
3. For each category below, is there clear evidence?
Categories: {", ".join(self.categories)}
Detailed analysis:
"""
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    max_length=8192
                )
            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Extract categories more reliably using multiple patterns
            categories_found = set()
            # Look for explicit category mentions
            category_matches = self.pattern.findall(full_response.lower())
            # Normalize and validate matches
            for match in category_matches:
                for category in self.categories:
                    if match.lower() == category.lower():
                        categories_found.add(category)
            # Sort for consistency
            matched_categories = sorted(categories_found)
            # Clean up reasoning text
            reasoning = full_response.split("\n\nCategories found:")[0] if "\n\nCategories found:" in full_response else full_response
            reasoning = reasoning.strip()
            if not matched_categories and any(trigger_word in full_response.lower() for trigger_word in
                                              ["concerning", "warning", "caution", "trigger", "sensitive"]):
                logger.warning("Potential triggers found but no categories matched in chunk")
            logger.info(f"Chunk analysis complete - Categories found: {matched_categories}")
            return matched_categories, reasoning
        except Exception as e:
            logger.error(f"Chunk analysis error: {str(e)}")
            return [], f"Analysis error: {str(e)}"
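
    # Illustrative extraction (hypothetical model output): a response containing
    # "The scene depicts Violence and gun use during the robbery." matches the
    # case-insensitive category pattern twice, so _analyze_chunk would return
    # (["Gun Use", "Violence"], <full response text>).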

    async def analyze_script(self, script: str, progress: Optional[gr.Progress] = None) -> tuple[List[str], List[str]]:
        """Main analysis workflow with progress updates"""
        if not script.strip():
            return ["No content provided"], ["No analysis performed"]
        identified_triggers = set()
        reasoning_outputs = []
        chunks = self._chunk_text(script)
        if not chunks:
            return ["Empty text after chunking"], ["No analysis performed"]
        total_chunks = len(chunks)
        for idx, chunk in enumerate(chunks):
            if progress:
                progress(idx / total_chunks, desc=f"Deep analysis of chunk {idx + 1}/{total_chunks}")
            chunk_triggers, chunk_reasoning = await self._analyze_chunk(chunk)
            identified_triggers.update(chunk_triggers)
            reasoning_outputs.append(f"Chunk {idx + 1} Analysis:\n{chunk_reasoning}")
            logger.info(f"Processed chunk {idx + 1}/{total_chunks}, found triggers: {chunk_triggers}")
        if progress:
            progress(1.0, desc="Analysis complete")
        final_triggers = sorted(identified_triggers) if identified_triggers else ["None"]
        logger.info(f"Final triggers identified: {final_triggers}")
        return final_triggers, reasoning_outputs
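

# For reference: ContentAnalyzer.analyze_script() returns (triggers, reasoning),
# e.g. (["Gun Use", "Violence"], ["Chunk 1 Analysis:\n...", ...]), or
# (["None"], [...]) when nothing is detected. The example values are
# illustrative only; analyze_content() below unpacks and post-processes the pair.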
async def analyze_content(
    script: str,
    progress: Optional[gr.Progress] = gr.Progress()
) -> tuple[Dict[str, Union[List[str], str]], str]:
    """Gradio interface function with enhanced trigger detection"""
    try:
        analyzer = ContentAnalyzer()
        triggers, reasoning_output = await analyzer.analyze_script(script, progress)
        # Extract triggers from detailed analysis
        detected_triggers = set()
        full_reasoning = "\n\n".join(reasoning_output)
        # Look for explicit category markers
        category_markers = [
            (r'\b(\w+):\s*\+', 1),  # Matches "Category: +"
            (r'\*\*(\w+(?:\s+\w+)?):\*\*[^\n]*?\bMarked with "\+"', 1),  # Matches '**Category:** ... Marked with "+"'
            (r'(\w+(?:\s+\w+)?)\s*is clearly present', 1),  # Matches "Category is clearly present"
        ]
        for pattern, group in category_markers:
            matches = re.finditer(pattern, full_reasoning, re.IGNORECASE)
            for match in matches:
                category = match.group(group).strip()
                # Normalize category names to match predefined categories
                for predefined_category in analyzer.categories:
                    if category.lower() in predefined_category.lower():
                        detected_triggers.add(predefined_category)
        # Add any triggers found through direct pattern matching
        for category in analyzer.categories:
            pattern = fr'\b{re.escape(category)}\b.*?(present|evident|indicated|clear|obvious)'
            if re.search(pattern, full_reasoning, re.IGNORECASE):
                detected_triggers.add(category)
        # If no triggers were found through detailed analysis, fall back to the original triggers
        final_triggers = sorted(detected_triggers) if detected_triggers else triggers
        result = {
            "detected_triggers": final_triggers if final_triggers else ["None"],
            "confidence": "High confidence" if final_triggers and final_triggers != ["None"] else "No triggers found",
            "model": "DeepSeek-R1-Distill-Qwen-1.5B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "analysis_reasoning": full_reasoning
        }
        logger.info(f"Enhanced analysis complete. Results: {result}")
        # Return one value per Gradio output component: JSON results, then reasoning text
        return result, full_reasoning
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}")
        error_result = {
            "detected_triggers": ["Analysis error"],
            "confidence": "Error",
            "model": "DeepSeek-R1-Distill-Qwen-1.5B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "analysis_reasoning": str(e),
            "error": str(e)
        }
        return error_result, str(e)


if __name__ == "__main__":
    iface = gr.Interface(
        fn=analyze_content,
        inputs=gr.Textbox(lines=12, label="Paste Script Here", placeholder="Enter text to analyze..."),
        outputs=[
            gr.JSON(label="Analysis Results"),
            gr.Textbox(label="Analysis Reasoning", lines=10)
        ],
        title="TREAT - Trigger Analysis for Entertainment Texts",
        description="Deep analysis of scripts for sensitive content using AI",
        allow_flagging="never"
    )
    iface.launch(show_error=True)
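
# Hypothetical programmatic use without the web UI (e.g. from a script or test).
# The sample text is made up, and passing progress=None skips Gradio's progress tracking:
#
#   import asyncio
#   sample = "INT. WAREHOUSE - NIGHT\n\nTwo men argue. One draws a gun and fires."
#   results, reasoning = asyncio.run(analyze_content(sample, progress=None))
#   print(results["detected_triggers"])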