Spaces:
Runtime error
Runtime error
| import torch | |
| import logging | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering | |
| from sentence_transformers import SentenceTransformer, util | |
| import numpy as np | |
| from typing import List, Dict, Any, Optional | |
| import re | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import json | |
| import os | |
| class EnhancedModelManager: | |
| """ | |
| Enhanced model manager with ensemble methods, better prompting, and multiple models | |
| for improved accuracy in legal document analysis. | |
| """ | |
| def __init__(self): | |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
| self.models = {} | |
| self.embedders = {} | |
| self.initialize_models() | |
| def initialize_models(self): | |
| """Initialize multiple models for ensemble approach""" | |
| try: | |
| # === Summarization Models === | |
| logging.info("Loading summarization models...") | |
| # Only the legal-specific summarizer | |
| self.models['legal_summarizer'] = pipeline( | |
| "summarization", | |
| model="TheGod-2003/legal-summarizer", | |
| tokenizer="TheGod-2003/legal-summarizer", | |
| device=0 if self.device == "cuda" else -1 | |
| ) | |
| logging.info("Legal summarization model loaded successfully") | |
| # === QA Models === | |
| logging.info("Loading QA models...") | |
| # Primary legal QA model | |
| self.models['legal_qa'] = pipeline( | |
| "question-answering", | |
| model="TheGod-2003/legal_QA_model", | |
| tokenizer="TheGod-2003/legal_QA_model", | |
| device=0 if self.device == "cuda" else -1 | |
| ) | |
| # Alternative QA models | |
| try: | |
| self.models['bert_qa'] = pipeline( | |
| "question-answering", | |
| model="deepset/roberta-base-squad2", | |
| device=0 if self.device == "cuda" else -1 | |
| ) | |
| except Exception as e: | |
| logging.warning(f"Could not load RoBERTa QA model: {e}") | |
| try: | |
| self.models['distilbert_qa'] = pipeline( | |
| "question-answering", | |
| model="distilbert-base-cased-distilled-squad", | |
| device=0 if self.device == "cuda" else -1 | |
| ) | |
| except Exception as e: | |
| logging.warning(f"Could not load DistilBERT QA model: {e}") | |
| # === Embedding Models === | |
| logging.info("Loading embedding models...") | |
| # Primary embedding model | |
| self.embedders['mpnet'] = SentenceTransformer('sentence-transformers/all-mpnet-base-v2') | |
| # Alternative embedding models for ensemble | |
| try: | |
| self.embedders['all_minilm'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') | |
| except Exception as e: | |
| logging.warning(f"Could not load all-MiniLM embedder: {e}") | |
| try: | |
| self.embedders['paraphrase'] = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2') | |
| except Exception as e: | |
| logging.warning(f"Could not load paraphrase embedder: {e}") | |
| logging.info("All models loaded successfully") | |
| except Exception as e: | |
| logging.error(f"Error initializing models: {e}") | |
| raise | |
| def generate_enhanced_summary(self, text: str, max_length: int = 4096, min_length: int = 200) -> Dict[str, Any]: | |
| """ | |
| Generate enhanced summary using ensemble approach with multiple models | |
| """ | |
| try: | |
| summaries = [] | |
| weights = [] | |
| cleaned_text = self._preprocess_text(text) | |
| # Handle long documents with improved chunking | |
| cleaned_text = self._handle_long_documents(cleaned_text) | |
| # Only legal summarizer | |
| if 'legal_summarizer' in self.models: | |
| try: | |
| # Improved parameters for LED-16384 model | |
| summary = self.models['legal_summarizer']( | |
| cleaned_text, | |
| max_length=max_length, | |
| min_length=min_length, | |
| num_beams=5, # Increased for better quality | |
| length_penalty=1.2, # Slightly favor longer summaries | |
| repetition_penalty=1.5, # Reduced to avoid over-penalization | |
| no_repeat_ngram_size=2, # Reduced for legal text | |
| early_stopping=False, # Disabled to prevent premature stopping | |
| do_sample=True, # Enable sampling for better diversity | |
| temperature=0.7, # Add some randomness | |
| top_p=0.9, # Nucleus sampling | |
| pad_token_id=self.models['legal_summarizer'].tokenizer.eos_token_id, | |
| eos_token_id=self.models['legal_summarizer'].tokenizer.eos_token_id | |
| )[0]['summary_text'] | |
| # Ensure summary is complete | |
| summary = self._ensure_complete_summary(summary, cleaned_text) | |
| # Retry if summary is too short or incomplete | |
| if len(summary.split()) < min_length or not summary.strip().endswith(('.', '!', '?')): | |
| logging.info("Summary too short or incomplete, retrying with different parameters...") | |
| retry_summary = self.models['legal_summarizer']( | |
| cleaned_text, | |
| max_length=max_length * 2, # Double the max length | |
| min_length=min_length, | |
| num_beams=3, # Reduce beams for faster generation | |
| length_penalty=1.5, # Favor longer summaries | |
| repetition_penalty=1.2, | |
| no_repeat_ngram_size=1, | |
| early_stopping=False, | |
| do_sample=False, # Disable sampling for more deterministic output | |
| pad_token_id=self.models['legal_summarizer'].tokenizer.eos_token_id, | |
| eos_token_id=self.models['legal_summarizer'].tokenizer.eos_token_id | |
| )[0]['summary_text'] | |
| retry_summary = self._ensure_complete_summary(retry_summary, cleaned_text) | |
| if len(retry_summary.split()) > len(summary.split()): | |
| summary = retry_summary | |
| summaries.append(summary) | |
| weights.append(1.0) | |
| except Exception as e: | |
| logging.warning(f"Legal summarizer failed: {e}") | |
| # Fallback to extractive summarization | |
| fallback_summary = self._extractive_summarization(cleaned_text, max_length) | |
| if fallback_summary: | |
| summaries.append(fallback_summary) | |
| weights.append(1.0) | |
| if not summaries: | |
| raise Exception("No models could generate summaries") | |
| final_summary = self._ensemble_summaries(summaries, weights) | |
| final_summary = self._postprocess_summary(final_summary, summaries, min_sentences=8) | |
| return { | |
| 'summary': final_summary, | |
| 'model_summaries': summaries, | |
| 'weights': weights, | |
| 'confidence': self._calculate_summary_confidence(final_summary, cleaned_text) | |
| } | |
| except Exception as e: | |
| logging.error(f"Error in enhanced summary generation: {e}") | |
| raise | |
| def answer_question_enhanced(self, question: str, context: str) -> Dict[str, Any]: | |
| """ | |
| Enhanced QA with ensemble approach and better context retrieval | |
| """ | |
| try: | |
| # Enhanced context retrieval | |
| enhanced_context = self._enhance_context(question, context) | |
| answers = [] | |
| scores = [] | |
| weights = [] | |
| # Generate answers with different models | |
| if 'legal_qa' in self.models: | |
| try: | |
| result = self.models['legal_qa']( | |
| question=question, | |
| context=enhanced_context | |
| ) | |
| answers.append(result['answer']) | |
| scores.append(result['score']) | |
| weights.append(0.5) # Higher weight for legal-specific model | |
| except Exception as e: | |
| logging.warning(f"Legal QA model failed: {e}") | |
| if 'bert_qa' in self.models: | |
| try: | |
| result = self.models['bert_qa']( | |
| question=question, | |
| context=enhanced_context | |
| ) | |
| answers.append(result['answer']) | |
| scores.append(result['score']) | |
| weights.append(0.3) | |
| except Exception as e: | |
| logging.warning(f"RoBERTa QA model failed: {e}") | |
| if 'distilbert_qa' in self.models: | |
| try: | |
| result = self.models['distilbert_qa']( | |
| question=question, | |
| context=enhanced_context | |
| ) | |
| answers.append(result['answer']) | |
| scores.append(result['score']) | |
| weights.append(0.2) | |
| except Exception as e: | |
| logging.warning(f"DistilBERT QA model failed: {e}") | |
| if not answers: | |
| raise Exception("No models could generate answers") | |
| # Ensemble the answers | |
| final_answer = self._ensemble_answers(answers, scores, weights) | |
| # Validate and enhance the answer | |
| enhanced_answer = self._enhance_answer(final_answer, question, enhanced_context) | |
| return { | |
| 'answer': enhanced_answer, | |
| 'confidence': np.average(scores, weights=weights), | |
| 'model_answers': answers, | |
| 'model_scores': scores, | |
| 'context_used': enhanced_context | |
| } | |
| except Exception as e: | |
| logging.error(f"Error in enhanced QA: {e}") | |
| raise | |
| def _enhance_context(self, question: str, context: str) -> str: | |
| """Enhanced context retrieval using multiple embedding models""" | |
| try: | |
| # Split context into sentences | |
| sentences = self._split_into_sentences(context) | |
| if len(sentences) <= 3: | |
| return context | |
| # Get embeddings from multiple models | |
| embeddings = {} | |
| for name, embedder in self.embedders.items(): | |
| try: | |
| sentence_embeddings = embedder.encode(sentences, convert_to_tensor=True) | |
| question_embedding = embedder.encode(question, convert_to_tensor=True) | |
| similarities = util.cos_sim(question_embedding, sentence_embeddings)[0] | |
| embeddings[name] = similarities.cpu().numpy() | |
| except Exception as e: | |
| logging.warning(f"Embedding model {name} failed: {e}") | |
| if not embeddings: | |
| return context | |
| # Ensemble similarities | |
| ensemble_similarities = np.mean(list(embeddings.values()), axis=0) | |
| # Get top sentences | |
| top_indices = np.argsort(ensemble_similarities)[-5:][::-1] # Top 5 sentences | |
| # Combine with semantic ordering | |
| relevant_sentences = [sentences[i] for i in sorted(top_indices)] | |
| return " ".join(relevant_sentences) | |
| except Exception as e: | |
| logging.warning(f"Context enhancement failed: {e}") | |
| return context | |
| def _ensemble_summaries(self, summaries: List[str], weights: List[float]) -> str: | |
| """Ensemble multiple summaries using semantic similarity""" | |
| try: | |
| if len(summaries) == 1: | |
| return summaries[0] | |
| # Normalize weights | |
| weights = np.array(weights) / np.sum(weights) | |
| # Use the primary model's summary as base | |
| base_summary = summaries[0] | |
| # For now, return the weighted combination of summaries | |
| # In a more sophisticated approach, you could use extractive methods | |
| # to combine the best parts of each summary | |
| return base_summary | |
| except Exception as e: | |
| logging.warning(f"Summary ensemble failed: {e}") | |
| return summaries[0] if summaries else "" | |
| def _ensemble_answers(self, answers: List[str], scores: List[float], weights: List[float]) -> str: | |
| """Ensemble multiple answers using confidence scores""" | |
| try: | |
| if len(answers) == 1: | |
| return answers[0] | |
| # Normalize weights | |
| weights = np.array(weights) / np.sum(weights) | |
| # Weighted voting based on confidence scores | |
| weighted_scores = np.array(scores) * weights | |
| best_index = np.argmax(weighted_scores) | |
| return answers[best_index] | |
| except Exception as e: | |
| logging.warning(f"Answer ensemble failed: {e}") | |
| return answers[0] if answers else "" | |
| def _enhance_answer(self, answer: str, question: str, context: str) -> str: | |
| """Enhance answer with post-processing and validation""" | |
| try: | |
| # Clean the answer | |
| answer = answer.strip() | |
| # Apply legal-specific post-processing | |
| answer = self._apply_legal_postprocessing(answer, question) | |
| # Validate answer against context | |
| if not self._validate_answer_context(answer, context): | |
| # Try to extract a better answer from context | |
| extracted_answer = self._extract_answer_from_context(question, context) | |
| if extracted_answer: | |
| answer = extracted_answer | |
| return answer | |
| except Exception as e: | |
| logging.warning(f"Answer enhancement failed: {e}") | |
| return answer | |
| def _apply_legal_postprocessing(self, answer: str, question: str) -> str: | |
| """Apply legal-specific post-processing rules""" | |
| try: | |
| # Remove common legal document artifacts | |
| answer = re.sub(r'\b(SEC\.|Section|Article)\s*\d+\.?', '', answer, flags=re.IGNORECASE) | |
| answer = re.sub(r'\s+', ' ', answer) | |
| # Handle specific question types | |
| question_lower = question.lower() | |
| if any(word in question_lower for word in ['how long', 'duration', 'period']): | |
| # Extract time-related information | |
| time_match = re.search(r'\d+\s*(years?|months?|days?|weeks?)', answer, re.IGNORECASE) | |
| if time_match: | |
| return time_match.group(0) | |
| elif any(word in question_lower for word in ['how much', 'cost', 'price', 'amount']): | |
| # Extract monetary information | |
| money_match = re.search(r'\$\d{1,3}(,\d{3})*(\.\d{2})?', answer) | |
| if money_match: | |
| return money_match.group(0) | |
| elif any(word in question_lower for word in ['when', 'date']): | |
| # Extract date information | |
| date_match = re.search(r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}', answer) | |
| if date_match: | |
| return date_match.group(0) | |
| return answer.strip() | |
| except Exception as e: | |
| logging.warning(f"Legal post-processing failed: {e}") | |
| return answer | |
| def _validate_answer_context(self, answer: str, context: str) -> bool: | |
| """Validate if answer is present in context""" | |
| try: | |
| # Simple validation - check if key terms from answer are in context | |
| answer_terms = set(word.lower() for word in answer.split() if len(word) > 3) | |
| context_terms = set(word.lower() for word in context.split()) | |
| # Check if at least 50% of answer terms are in context | |
| if answer_terms: | |
| overlap = len(answer_terms.intersection(context_terms)) / len(answer_terms) | |
| return overlap >= 0.5 | |
| return True | |
| except Exception as e: | |
| logging.warning(f"Answer validation failed: {e}") | |
| return True | |
| def _extract_answer_from_context(self, question: str, context: str) -> Optional[str]: | |
| """Extract answer directly from context using patterns""" | |
| try: | |
| question_lower = question.lower() | |
| if any(word in question_lower for word in ['how long', 'duration', 'period']): | |
| match = re.search(r'\d+\s*(years?|months?|days?|weeks?)', context, re.IGNORECASE) | |
| return match.group(0) if match else None | |
| elif any(word in question_lower for word in ['how much', 'cost', 'price', 'amount']): | |
| match = re.search(r'\$\d{1,3}(,\d{3})*(\.\d{2})?', context) | |
| return match.group(0) if match else None | |
| elif any(word in question_lower for word in ['when', 'date']): | |
| match = re.search(r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}', context) | |
| return match.group(0) if match else None | |
| return None | |
| except Exception as e: | |
| logging.warning(f"Answer extraction failed: {e}") | |
| return None | |
| def _preprocess_text(self, text: str) -> str: | |
| """Preprocess text for better model performance""" | |
| try: | |
| # Remove common artifacts but preserve legal structure | |
| text = re.sub(r'[\\\n\r\u200b\u2022\u00a0_=]+', ' ', text) | |
| text = re.sub(r'<.*?>', ' ', text) | |
| # Preserve legal citations and numbers (don't remove them completely) | |
| # Instead of removing section numbers, normalize them | |
| text = re.sub(r'\b(SEC\.|Section|Article)\s*(\d+)\.?', r'Section \2', text, flags=re.IGNORECASE) | |
| # Clean up excessive whitespace | |
| text = re.sub(r'\s{2,}', ' ', text) | |
| # Preserve important legal punctuation and formatting | |
| text = re.sub(r'([.!?])\s*([A-Z])', r'\1 \2', text) # Ensure proper sentence spacing | |
| # Remove non-printable characters but keep legal symbols | |
| text = re.sub(r'[^\x00-\x7F]+', ' ', text) | |
| # Ensure proper spacing around legal terms | |
| text = re.sub(r'\b(Lessee|Lessor|Party|Parties)\b', r' \1 ', text, flags=re.IGNORECASE) | |
| return text.strip() | |
| except Exception as e: | |
| logging.warning(f"Text preprocessing failed: {e}") | |
| return text | |
| def _chunk_text_for_summarization(self, text: str, max_words: int = 8000) -> str: | |
| """Chunk long text for summarization while preserving legal document structure""" | |
| try: | |
| words = text.split() | |
| if len(words) <= max_words: | |
| return text | |
| # Split into sentences first | |
| sentences = self._split_into_sentences(text) | |
| # Take the most important sentences (first and last portions) | |
| total_sentences = len(sentences) | |
| if total_sentences <= 50: | |
| return text | |
| # Take first 60% and last 20% of sentences | |
| first_portion = int(total_sentences * 0.6) | |
| last_portion = int(total_sentences * 0.2) | |
| selected_sentences = sentences[:first_portion] + sentences[-last_portion:] | |
| chunked_text = " ".join(selected_sentences) | |
| # Ensure we don't exceed token limit | |
| if len(chunked_text.split()) > max_words: | |
| chunked_text = " ".join(chunked_text.split()[:max_words]) | |
| return chunked_text | |
| except Exception as e: | |
| logging.warning(f"Text chunking failed: {e}") | |
| return text | |
| def _handle_long_documents(self, text: str) -> str: | |
| """Handle very long documents by using a sliding window approach""" | |
| try: | |
| # LED-16384 has a context window of ~16k tokens | |
| # Conservative estimate: ~12k tokens for input to leave room for generation | |
| max_tokens = 12000 | |
| # Approximate tokens (roughly 1.3 words per token for English) | |
| words = text.split() | |
| if len(words) <= max_tokens * 0.8: # Conservative limit | |
| return text | |
| # Use sliding window approach for very long documents | |
| sentences = self._split_into_sentences(text) | |
| if len(sentences) < 10: | |
| return text | |
| # Take key sections: beginning, middle, and end | |
| total_sentences = len(sentences) | |
| # Take first 40%, middle 20%, and last 40% | |
| first_end = int(total_sentences * 0.4) | |
| middle_start = int(total_sentences * 0.4) | |
| middle_end = int(total_sentences * 0.6) | |
| last_start = int(total_sentences * 0.6) | |
| key_sentences = ( | |
| sentences[:first_end] + | |
| sentences[middle_start:middle_end] + | |
| sentences[last_start:] | |
| ) | |
| # Ensure we don't exceed token limit | |
| combined_text = " ".join(key_sentences) | |
| words = combined_text.split() | |
| if len(words) > max_tokens * 0.8: | |
| # Truncate to safe limit | |
| combined_text = " ".join(words[:int(max_tokens * 0.8)]) | |
| return combined_text | |
| except Exception as e: | |
| logging.warning(f"Long document handling failed: {e}") | |
| return text | |
| def _ensure_complete_summary(self, summary: str, original_text: str) -> str: | |
| """Ensure the summary is complete and not truncated mid-sentence""" | |
| try: | |
| if not summary: | |
| return summary | |
| # Check if summary ends with complete sentence | |
| if not summary.rstrip().endswith(('.', '!', '?')): | |
| # Find the last complete sentence | |
| sentences = summary.split('. ') | |
| if len(sentences) > 1: | |
| # Remove the incomplete last sentence | |
| summary = '. '.join(sentences[:-1]) + '.' | |
| # Ensure minimum length | |
| if len(summary.split()) < 50: | |
| # Try to extract more content from original text | |
| additional_content = self._extract_key_sentences(original_text, 100) | |
| if additional_content: | |
| summary = summary + " " + additional_content | |
| return summary.strip() | |
| except Exception as e: | |
| logging.warning(f"Summary completion check failed: {e}") | |
| return summary | |
| def _extract_key_sentences(self, text: str, max_words: int = 100) -> str: | |
| """Extract key sentences from text for summary completion""" | |
| try: | |
| sentences = self._split_into_sentences(text) | |
| # Simple heuristic: take sentences with legal keywords | |
| legal_keywords = ['lease', 'rent', 'payment', 'term', 'agreement', 'lessor', 'lessee', | |
| 'covenant', 'obligation', 'right', 'duty', 'termination', 'renewal'] | |
| key_sentences = [] | |
| word_count = 0 | |
| for sentence in sentences: | |
| sentence_lower = sentence.lower() | |
| if any(keyword in sentence_lower for keyword in legal_keywords): | |
| sentence_words = len(sentence.split()) | |
| if word_count + sentence_words <= max_words: | |
| key_sentences.append(sentence) | |
| word_count += sentence_words | |
| else: | |
| break | |
| return " ".join(key_sentences) | |
| except Exception as e: | |
| logging.warning(f"Key sentence extraction failed: {e}") | |
| return "" | |
| def _extractive_summarization(self, text: str, max_length: int) -> str: | |
| """Fallback extractive summarization using TF-IDF""" | |
| try: | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| sentences = self._split_into_sentences(text) | |
| if len(sentences) < 3: | |
| return text | |
| # Create TF-IDF vectors | |
| vectorizer = TfidfVectorizer(stop_words='english', max_features=1000) | |
| tfidf_matrix = vectorizer.fit_transform(sentences) | |
| # Calculate sentence importance based on TF-IDF scores | |
| sentence_scores = [] | |
| for i in range(len(sentences)): | |
| score = tfidf_matrix[i].sum() | |
| sentence_scores.append((score, i)) | |
| # Sort by score and take top sentences | |
| sentence_scores.sort(reverse=True) | |
| # Select sentences up to max_length | |
| selected_indices = [] | |
| total_words = 0 | |
| for score, idx in sentence_scores: | |
| sentence_words = len(sentences[idx].split()) | |
| if total_words + sentence_words <= max_length // 2: # Conservative estimate | |
| selected_indices.append(idx) | |
| total_words += sentence_words | |
| else: | |
| break | |
| # Sort by original order | |
| selected_indices.sort() | |
| summary_sentences = [sentences[i] for i in selected_indices] | |
| return " ".join(summary_sentences) | |
| except Exception as e: | |
| logging.warning(f"Extractive summarization failed: {e}") | |
| return text[:max_length] if len(text) > max_length else text | |
| def _postprocess_summary(self, summary: str, all_summaries: Optional[List[str]] = None, min_sentences: int = 10) -> str: | |
| """Post-process summary for better readability""" | |
| try: | |
| summary = re.sub(r'[\\\n\r\u200b\u2022\u00a0_=]+', ' ', summary) | |
| summary = re.sub(r'[^\x00-\x7F]+', ' ', summary) | |
| summary = re.sub(r'\s{2,}', ' ', summary) | |
| # Remove redundant sentences | |
| sentences = summary.split('. ') | |
| unique_sentences = [] | |
| for sentence in sentences: | |
| s = sentence.strip() | |
| if s and s not in unique_sentences: | |
| unique_sentences.append(s) | |
| # If too short, add more unique sentences from other model outputs | |
| if all_summaries is not None and len(unique_sentences) < min_sentences: | |
| all_sentences = [] | |
| for summ in all_summaries: | |
| all_sentences.extend([s.strip() for s in summ.split('. ') if s.strip()]) | |
| for s in all_sentences: | |
| if s not in unique_sentences: | |
| unique_sentences.append(s) | |
| if len(unique_sentences) >= min_sentences: | |
| break | |
| return '. '.join(unique_sentences) | |
| except Exception as e: | |
| logging.warning(f"Summary post-processing failed: {e}") | |
| return summary | |
| def _split_into_sentences(self, text: str) -> List[str]: | |
| """Split text into sentences with improved handling for legal documents""" | |
| try: | |
| # More sophisticated sentence splitting for legal documents | |
| # Handle legal abbreviations and citations properly | |
| text = re.sub(r'([.!?])\s*([A-Z])', r'\1 \2', text) | |
| # Split on sentence endings, but be careful with legal citations | |
| sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text) | |
| # Clean up sentences | |
| cleaned_sentences = [] | |
| for sentence in sentences: | |
| sentence = sentence.strip() | |
| if sentence and len(sentence) > 10: # Filter out very short fragments | |
| # Handle legal abbreviations that might have been split | |
| if sentence.startswith(('Sec', 'Art', 'Clause', 'Para')): | |
| # This might be a continuation, try to merge with previous | |
| if cleaned_sentences: | |
| cleaned_sentences[-1] = cleaned_sentences[-1] + " " + sentence | |
| else: | |
| cleaned_sentences.append(sentence) | |
| else: | |
| cleaned_sentences.append(sentence) | |
| return cleaned_sentences if cleaned_sentences else [text] | |
| except Exception as e: | |
| logging.warning(f"Sentence splitting failed: {e}") | |
| return [text] | |
| def _calculate_summary_confidence(self, summary: str, original_text: str) -> float: | |
| """Calculate confidence score for summary""" | |
| try: | |
| # Simple confidence based on summary length and content | |
| if not summary or len(summary) < 10: | |
| return 0.0 | |
| # Check if summary contains key terms from original text | |
| summary_terms = set(word.lower() for word in summary.split() if len(word) > 3) | |
| original_terms = set(word.lower() for word in original_text.split() if len(word) > 3) | |
| if original_terms: | |
| overlap = len(summary_terms.intersection(original_terms)) / len(original_terms) | |
| return min(overlap * 2, 1.0) # Scale overlap to 0-1 range | |
| return 0.5 # Default confidence | |
| except Exception as e: | |
| logging.warning(f"Confidence calculation failed: {e}") | |
| return 0.5 | |
| # Global instance | |
| enhanced_model_manager = EnhancedModelManager() |