Spaces:
Runtime error
Runtime error
Update embedding model to Google Generative AI and enhance vector store functionality
Browse files
- Changed the embedding model from OpenAI to Google Generative AI in the EmbeddingManager class.
- Updated the configuration to reflect the new embedding model path.
- Modified validation checks to ensure the presence of the Google API key for RAG embeddings.
- Added a new method to reset the vector store, allowing for a complete clear and recreation of the collection.
- Enhanced logging to provide clearer feedback on embedding model initialization and vector store operations.
- src/core/config.py +3 -3
- src/rag/embeddings.py +17 -19
- src/rag/vector_store.py +32 -1
src/core/config.py
CHANGED
|
@@ -86,7 +86,7 @@ class RAGConfig:
|
|
| 86 |
chat_history_path: str = "./data/chat_history"
|
| 87 |
|
| 88 |
# Embedding settings
|
| 89 |
-
embedding_model: str = "text-embedding-
|
| 90 |
embedding_chunk_size: int = 1000
|
| 91 |
|
| 92 |
# Chunking settings
|
|
@@ -182,8 +182,8 @@ class Config:
|
|
| 182 |
validation_results["warnings"].append("Mistral API key not found - Mistral parser will be unavailable")
|
| 183 |
|
| 184 |
# Check RAG dependencies
|
| 185 |
-
if not self.api.
|
| 186 |
-
validation_results["warnings"].append("
|
| 187 |
|
| 188 |
if not self.api.google_api_key:
|
| 189 |
validation_results["warnings"].append("Google API key not found - RAG chat will be unavailable")
|
|
|
|
| 86 |
chat_history_path: str = "./data/chat_history"
|
| 87 |
|
| 88 |
# Embedding settings
|
| 89 |
+
embedding_model: str = "models/text-embedding-004"
|
| 90 |
embedding_chunk_size: int = 1000
|
| 91 |
|
| 92 |
# Chunking settings
|
|
|
|
| 182 |
validation_results["warnings"].append("Mistral API key not found - Mistral parser will be unavailable")
|
| 183 |
|
| 184 |
# Check RAG dependencies
|
| 185 |
+
if not self.api.google_api_key:
|
| 186 |
+
validation_results["warnings"].append("Google API key not found - RAG embeddings will be unavailable")
|
| 187 |
|
| 188 |
if not self.api.google_api_key:
|
| 189 |
validation_results["warnings"].append("Google API key not found - RAG chat will be unavailable")
|
src/rag/embeddings.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
from typing import Optional
|
| 5 |
-
from
|
| 6 |
from src.core.config import config
|
| 7 |
from src.core.logging_config import get_logger
|
| 8 |
|
|
@@ -12,30 +12,28 @@ class EmbeddingManager:
|
|
| 12 |
"""Manages embedding models for document vectorization."""
|
| 13 |
|
| 14 |
def __init__(self):
|
| 15 |
-
self._embedding_model: Optional[
|
| 16 |
|
| 17 |
-
def get_embedding_model(self) ->
|
| 18 |
-
"""Get or create the
|
| 19 |
if self._embedding_model is None:
|
| 20 |
try:
|
| 21 |
-
# Get
|
| 22 |
-
|
| 23 |
|
| 24 |
-
if not
|
| 25 |
-
raise ValueError("
|
| 26 |
|
| 27 |
-
self._embedding_model =
|
| 28 |
-
model=
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
max_retries=3,
|
| 32 |
-
timeout=30
|
| 33 |
)
|
| 34 |
|
| 35 |
-
logger.info("
|
| 36 |
|
| 37 |
except Exception as e:
|
| 38 |
-
logger.error(f"Failed to initialize
|
| 39 |
raise
|
| 40 |
|
| 41 |
return self._embedding_model
|
|
@@ -50,14 +48,14 @@ class EmbeddingManager:
|
|
| 50 |
|
| 51 |
# Check if we got a valid embedding (list of floats)
|
| 52 |
if isinstance(embedding, list) and len(embedding) > 0 and isinstance(embedding[0], float):
|
| 53 |
-
logger.info("
|
| 54 |
return True
|
| 55 |
else:
|
| 56 |
-
logger.error("
|
| 57 |
return False
|
| 58 |
|
| 59 |
except Exception as e:
|
| 60 |
-
logger.error(f"
|
| 61 |
return False
|
| 62 |
|
| 63 |
# Global embedding manager instance
|
|
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
from typing import Optional
|
| 5 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 6 |
from src.core.config import config
|
| 7 |
from src.core.logging_config import get_logger
|
| 8 |
|
|
|
|
| 12 |
"""Manages embedding models for document vectorization."""
|
| 13 |
|
| 14 |
def __init__(self):
    """Create the manager with no embedding model loaded (lazy initialization)."""
    # The Gemini embedding model is built on first call to get_embedding_model().
    self._embedding_model: Optional[GoogleGenerativeAIEmbeddings] = None
|
| 16 |
|
| 17 |
+
def get_embedding_model(self) -> GoogleGenerativeAIEmbeddings:
|
| 18 |
+
"""Get or create the Gemini embedding model."""
|
| 19 |
if self._embedding_model is None:
|
| 20 |
try:
|
| 21 |
+
# Get Google API key from config/environment
|
| 22 |
+
google_api_key = config.api.google_api_key or os.getenv("GOOGLE_API_KEY")
|
| 23 |
|
| 24 |
+
if not google_api_key:
|
| 25 |
+
raise ValueError("Google API key not found. Please set GOOGLE_API_KEY in environment variables.")
|
| 26 |
|
| 27 |
+
self._embedding_model = GoogleGenerativeAIEmbeddings(
|
| 28 |
+
model=config.rag.embedding_model,
|
| 29 |
+
google_api_key=google_api_key,
|
| 30 |
+
task_type="RETRIEVAL_DOCUMENT"
|
|
|
|
|
|
|
| 31 |
)
|
| 32 |
|
| 33 |
+
logger.info(f"Gemini embedding model ({config.rag.embedding_model}) initialized successfully")
|
| 34 |
|
| 35 |
except Exception as e:
|
| 36 |
+
logger.error(f"Failed to initialize Gemini embedding model: {e}")
|
| 37 |
raise
|
| 38 |
|
| 39 |
return self._embedding_model
|
|
|
|
| 48 |
|
| 49 |
# Check if we got a valid embedding (list of floats)
|
| 50 |
if isinstance(embedding, list) and len(embedding) > 0 and isinstance(embedding[0], float):
|
| 51 |
+
logger.info("Gemini embedding model test successful")
|
| 52 |
return True
|
| 53 |
else:
|
| 54 |
+
logger.error("Gemini embedding model test failed: Invalid embedding format")
|
| 55 |
return False
|
| 56 |
|
| 57 |
except Exception as e:
|
| 58 |
+
logger.error(f"Gemini embedding model test failed: {e}")
|
| 59 |
return False
|
| 60 |
|
| 61 |
# Global embedding manager instance
|
src/rag/vector_store.py
CHANGED
|
@@ -70,6 +70,7 @@ class VectorStoreManager:
|
|
| 70 |
|
| 71 |
logger.info(f"VectorStoreManager initialized with persist_directory={self.persist_directory}")
|
| 72 |
|
|
|
|
| 73 |
def get_vector_store(self) -> Chroma:
|
| 74 |
"""Get or create the Chroma vector store."""
|
| 75 |
if self._vector_store is None:
|
|
@@ -314,7 +315,7 @@ class VectorStoreManager:
|
|
| 314 |
"collection_name": self.collection_name,
|
| 315 |
"persist_directory": self.persist_directory,
|
| 316 |
"document_count": count,
|
| 317 |
-
"embedding_model":
|
| 318 |
}
|
| 319 |
|
| 320 |
logger.info(f"Collection info: {info}")
|
|
@@ -371,6 +372,36 @@ class VectorStoreManager:
|
|
| 371 |
logger.error(f"Error searching with metadata filter: {e}")
|
| 372 |
return []
|
| 373 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
def clear_all_documents(self) -> bool:
|
| 375 |
"""
|
| 376 |
Clear all documents from the vector store collection.
|
|
|
|
| 70 |
|
| 71 |
logger.info(f"VectorStoreManager initialized with persist_directory={self.persist_directory}")
|
| 72 |
|
| 73 |
+
|
| 74 |
def get_vector_store(self) -> Chroma:
|
| 75 |
"""Get or create the Chroma vector store."""
|
| 76 |
if self._vector_store is None:
|
|
|
|
| 315 |
"collection_name": self.collection_name,
|
| 316 |
"persist_directory": self.persist_directory,
|
| 317 |
"document_count": count,
|
| 318 |
+
"embedding_model": config.rag.embedding_model
|
| 319 |
}
|
| 320 |
|
| 321 |
logger.info(f"Collection info: {info}")
|
|
|
|
| 372 |
logger.error(f"Error searching with metadata filter: {e}")
|
| 373 |
return []
|
| 374 |
|
| 375 |
+
def reset_vector_store(self) -> bool:
    """
    Reset the vector store completely.
    This will clear all documents and recreate the collection.

    Returns:
        True if successful, False otherwise
    """
    try:
        logger.info("Resetting vector store...")

        # Wipe all stored documents; bail out early if that fails.
        if not self.clear_all_documents():
            logger.error("Failed to reset vector store")
            return False

        # Drop the underlying collection so the next access rebuilds a
        # clean one, and forget the cached handle.
        store = self._vector_store
        if store is not None:
            store.delete_collection()
            self._vector_store = None

        logger.info("Vector store reset successfully")
        return True
    except Exception as e:
        # Best-effort operation: report the failure instead of raising.
        logger.error(f"Error resetting vector store: {e}")
        return False
|
| 405 |
def clear_all_documents(self) -> bool:
|
| 406 |
"""
|
| 407 |
Clear all documents from the vector store collection.
|