Kalpokoch committed on
Commit ce750f8 · verified · 1 Parent(s): 7e6e5a8

Update app/policy_vector_db.py

Files changed (1)
  1. app/policy_vector_db.py +40 -263
app/policy_vector_db.py CHANGED
@@ -1,16 +1,11 @@
  import os
  import json
  import torch
- import re
- import hashlib
- from typing import List, Dict, Optional, Tuple
+ from typing import List, Dict
  from sentence_transformers import SentenceTransformer
  import chromadb
  from chromadb.config import Settings
  import logging
- import multiprocessing as mp
- from concurrent.futures import ThreadPoolExecutor
- import numpy as np

  # --- Basic Logging Setup ---
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -18,50 +13,28 @@ logger = logging.getLogger(__name__)

  class PolicyVectorDB:
      """
-     Enhanced vector database for policy documents with metadata-aware search capabilities.
-     Optimized for CPU utilization.
+     Manages the connection, population, and querying of a ChromaDB vector database
+     for policy documents.
      """
      def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.5):
          self.persist_directory = persist_directory
          self.client = chromadb.PersistentClient(path=persist_directory, settings=Settings(allow_reset=True))
          self.collection_name = "neepco_dop_policies"

-         # Optimize CPU usage
-         self.cpu_count = mp.cpu_count()
-         torch.set_num_threads(self.cpu_count)
-
-         logger.info(f"Detected {self.cpu_count} CPU cores, optimizing threading...")
+         # Using a powerful open-source embedding model.
+         # Change 'cpu' to 'cuda' if a GPU is available for significantly faster embedding.
          logger.info("Loading embedding model 'BAAI/bge-large-en-v1.5'. This may take a moment...")
-
-         # Optimize model loading for CPU
-         self.embedding_model = SentenceTransformer(
-             'BAAI/bge-large-en-v1.5',
-             device='cpu',
-             # Use all available CPU cores for inference
-             model_kwargs={'torch_dtype': torch.float32}
-         )
-
-         # Set model to use optimized CPU inference
-         self.embedding_model.max_seq_length = 512  # Reduce context length for speed
-
+         self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cpu')
          logger.info("Embedding model loaded successfully.")

-         self.collection = None
+         self.collection = None  # Initialize collection as None for lazy loading
          self.top_k_default = top_k_default
          self.relevance_threshold = relevance_threshold
-
-         # Thread pool for parallel processing
-         self.thread_pool = ThreadPoolExecutor(max_workers=self.cpu_count)
-
-         # Add monetary normalization for queries
-         self.money_patterns = {
-             r'(\d+(?:,\d+)*(?:\.\d+)?)\s*crore': lambda x: float(x.replace(',', '')) * 1e7,
-             r'(\d+(?:,\d+)*(?:\.\d+)?)\s*lakh': lambda x: float(x.replace(',', '')) * 1e5,
-             r'₹\s*(\d+(?:,\d+)*(?:\.\d+)?)': lambda x: float(x.replace(',', ''))
-         }

      def _get_collection(self):
-         """Retrieves or creates the ChromaDB collection. Implements lazy loading."""
+         """
+         Retrieves or creates the ChromaDB collection. Implements lazy loading.
+         """
          if self.collection is None:
              self.collection = self.client.get_or_create_collection(
                  name=self.collection_name,
@@ -70,90 +43,13 @@ class PolicyVectorDB:
          return self.collection

      def _flatten_metadata(self, metadata: Dict) -> Dict:
-         """Ensures all metadata values are strings, as required by ChromaDB."""
-         flattened = {}
-         for key, value in metadata.items():
-             if isinstance(value, (dict, list)):
-                 # Convert complex structures to JSON strings
-                 flattened[key] = json.dumps(value, ensure_ascii=False)
-             elif value is not None:
-                 flattened[key] = str(value)
-         return flattened
+         """Ensures all metadata values are strings, as required by some ChromaDB versions."""
+         return {key: str(value) for key, value in metadata.items()}
-
-     def _extract_query_entities(self, query: str) -> Dict[str, any]:
-         """Extract structured entities from user queries for better filtering."""
-         entities = {
-             'monetary_values': [],
-             'roles': [],
-             'sections': [],
-             'keywords': []
-         }
-
-         # Extract monetary amounts
-         for pattern, converter in self.money_patterns.items():
-             matches = re.finditer(pattern, query, re.IGNORECASE)
-             for match in matches:
-                 try:
-                     value = converter(match.group(1))
-                     entities['monetary_values'].append(value)
-                 except:
-                     pass
-
-         # Extract common roles
-         role_patterns = [
-             r'\b(CMD|Chairman|Managing Director)\b',
-             r'\b(Director|D\([PT]\)|D\(P\)|D\(T\))\b',
-             r'\b(ED|Executive Director)\b',
-             r'\b(CGM|Chief General Manager)\b',
-             r'\b(GM|General Manager)\b',
-             r'\b(DGM|Deputy General Manager)\b',
-             r'\b(Sr\.?\s*M|Senior Manager)\b'
-         ]
-
-         for pattern in role_patterns:
-             matches = re.finditer(pattern, query, re.IGNORECASE)
-             entities['roles'].extend([match.group() for match in matches])
-
-         # Extract section references
-         section_matches = re.finditer(r'\b(Section|Annexure)\s*([IVX]+|[A-Z])\b', query, re.IGNORECASE)
-         entities['sections'].extend([match.group() for match in section_matches])
-
-         return entities
-
-     def _encode_batch_parallel(self, texts: List[str]) -> np.ndarray:
-         """Parallel encoding of text batches for better CPU utilization."""
-         # Split texts into smaller batches for parallel processing
-         batch_size = max(1, len(texts) // self.cpu_count)
-         if len(texts) <= batch_size:
-             return self.embedding_model.encode(
-                 texts,
-                 normalize_embeddings=True,
-                 show_progress_bar=False,
-                 batch_size=32,  # Optimize batch size for CPU
-                 convert_to_numpy=True
-             )
-
-         # Process in parallel batches
-         batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
-
-         def encode_batch(batch):
-             return self.embedding_model.encode(
-                 batch,
-                 normalize_embeddings=True,
-                 show_progress_bar=False,
-                 batch_size=16,
-                 convert_to_numpy=True
-             )
-
-         # Use thread pool for parallel encoding
-         futures = [self.thread_pool.submit(encode_batch, batch) for batch in batches]
-         results = [future.result() for future in futures]
-
-         # Concatenate results
-         return np.vstack(results) if results else np.array([])

      def add_chunks(self, chunks: List[Dict]):
-         """Enhanced chunk addition with better metadata handling and parallel processing."""
+         """
+         Adds a list of chunks to the vector database, skipping any that already exist.
+         """
          collection = self._get_collection()
          if not chunks:
              logger.info("No chunks provided to add.")
@@ -174,9 +70,8 @@ class PolicyVectorDB:

          logger.info(f"Adding {len(new_chunks)} new chunks to the vector database...")

-         # Optimized batch size for CPU processing
-         batch_size = min(64, max(16, len(new_chunks) // 4))
-
+         # Process in batches for efficiency
+         batch_size = 32  # Reduced batch size for potentially large embeddings
          for i in range(0, len(new_chunks), batch_size):
              batch = new_chunks[i:i + batch_size]

@@ -184,168 +79,56 @@ class PolicyVectorDB:
              texts = [chunk['text'] for chunk in batch]
              metadatas = [self._flatten_metadata(chunk.get('metadata', {})) for chunk in batch]

-             # Use parallel encoding
-             embeddings = self._encode_batch_parallel(texts).tolist()
+             # For BGE models, it's recommended not to add instructions to the document embeddings
+             embeddings = self.embedding_model.encode(texts, normalize_embeddings=True, show_progress_bar=False).tolist()

              collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
              logger.info(f"Added batch {i//batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")

          logger.info(f"Finished adding {len(new_chunks)} chunks.")

-     def search(self, query_text: str, top_k: int = None, filters: Dict = None) -> List[Dict]:
+     def search(self, query_text: str, top_k: int = None) -> List[Dict]:
          """
-         Enhanced search with metadata filtering and entity extraction.
-         Optimized for CPU performance.
+         Searches the vector database for a given query text.
+         Returns a list of results filtered by a relevance threshold.
          """
          collection = self._get_collection()

-         # Extract entities from query for potential filtering
-         entities = self._extract_query_entities(query_text)
-
-         # Build metadata filters
-         where_conditions = {}
-         if filters:
-             where_conditions.update(filters)
-
-         # Add entity-based filters
-         if entities['roles']:
-             # Filter by role if mentioned in query
-             where_conditions["role"] = {"$in": entities['roles']}
-
-         if entities['sections']:
-             # Filter by section if mentioned
-             where_conditions["section"] = {"$in": [s.split()[-1] for s in entities['sections']]}
-
+         # IMPROVEMENT: Add the recommended instruction prefix for BGE retrieval models.
          instructed_query = f"Represent this sentence for searching relevant passages: {query_text}"

-         # Optimized single query encoding
-         query_embedding = self.embedding_model.encode(
-             [instructed_query],
-             normalize_embeddings=True,
-             show_progress_bar=False,
-             batch_size=1,
-             convert_to_numpy=True
-         ).tolist()
+         # IMPROVEMENT: Normalize embeddings for more accurate similarity search.
+         query_embedding = self.embedding_model.encode([instructed_query], normalize_embeddings=True).tolist()

          k = top_k if top_k is not None else self.top_k_default

-         # Perform search with metadata filtering
-         search_params = {
-             "query_embeddings": query_embedding,
-             "n_results": k * 3,  # Get more for filtering
-             "include": ["documents", "metadatas", "distances"]
-         }
-
-         if where_conditions:
-             search_params["where"] = where_conditions
-
-         results = collection.query(**search_params)
+         # Retrieve more results initially to allow for filtering
+         results = collection.query(
+             query_embeddings=query_embedding,
+             n_results=k * 2,  # Retrieve more to filter by threshold
+             include=["documents", "metadatas", "distances"]
+         )

          search_results = []
          if results and results.get('documents') and results['documents'][0]:
              for i, doc in enumerate(results['documents'][0]):
+                 # The distance for normalized embeddings is often interpreted as 1 - cosine_similarity
                  relevance_score = 1 - results['distances'][0][i]

                  if relevance_score >= self.relevance_threshold:
-                     result = {
+                     search_results.append({
                          'text': doc,
                          'metadata': results['metadatas'][0][i],
                          'relevance_score': relevance_score
-                     }
-
-                     # Add monetary filtering if amounts mentioned in query
-                     if entities['monetary_values'] and 'limit_normalized' in results['metadatas'][0][i]:
-                         try:
-                             chunk_limit = float(results['metadatas'][0][i]['limit_normalized'])
-                             query_amount = max(entities['monetary_values'])
-
-                             # Boost relevance if the limit is appropriate for the query amount
-                             if chunk_limit >= query_amount:
-                                 result['relevance_score'] += 0.1  # Small boost for relevant limits
-                         except:
-                             pass
-
-                     search_results.append(result)
-
-         return sorted(search_results, key=lambda x: x['relevance_score'], reverse=True)[:k]
-
-     def search_with_context(self, query_text: str, top_k: int = None, include_related: bool = True) -> List[Dict]:
-         """
-         Search with automatic inclusion of related/parent chunks for better context.
-         """
-         primary_results = self.search(query_text, top_k)
-
-         if not include_related or not primary_results:
-             return primary_results
-
-         # Find related chunks based on parent_id relationships
-         related_ids = set()
-         for result in primary_results:
-             metadata = result['metadata']
-             parent_id = metadata.get('parent_id')
-             if parent_id:
-                 related_ids.add(parent_id)
-
-         if related_ids:
-             collection = self._get_collection()
-             try:
-                 related_chunks = collection.get(
-                     ids=list(related_ids),
-                     include=["documents", "metadatas"]
-                 )
-
-                 for i, doc in enumerate(related_chunks['documents']):
-                     primary_results.append({
-                         'text': doc,
-                         'metadata': related_chunks['metadatas'][i],
-                         'relevance_score': 0.3,  # Lower score for context
-                         'is_context': True
                      })
-             except Exception as e:
-                 logger.warning(f"Could not retrieve related chunks: {e}")

-         return sorted(primary_results, key=lambda x: x['relevance_score'], reverse=True)
-
-     def search_by_amount(self, amount: float, comparison: str = ">=", top_k: int = None) -> List[Dict]:
-         """Search for delegation limits based on monetary amount."""
-         collection = self._get_collection()
-
-         where_condition = {}
-         if comparison == ">=":
-             where_condition = {"limit_normalized": {"$gte": amount}}
-         elif comparison == "<=":
-             where_condition = {"limit_normalized": {"$lte": amount}}
-         elif comparison == "==":
-             where_condition = {"limit_normalized": {"$eq": amount}}
-
-         try:
-             results = collection.get(
-                 where=where_condition,
-                 include=["documents", "metadatas"]
-             )
-
-             search_results = []
-             if results and results.get('documents'):
-                 for i, doc in enumerate(results['documents']):
-                     search_results.append({
-                         'text': doc,
-                         'metadata': results['metadatas'][i],
-                         'relevance_score': 1.0  # Perfect match for structured query
-                     })
-
-             k = top_k if top_k is not None else self.top_k_default
-             return search_results[:k]
-         except Exception as e:
-             logger.warning(f"Error in search_by_amount: {e}")
-             return []
-
-     def __del__(self):
-         """Cleanup thread pool on deletion."""
-         if hasattr(self, 'thread_pool'):
-             self.thread_pool.shutdown(wait=False)
+         # Sort by relevance score and return the top_k results
+         return sorted(search_results, key=lambda x: x['relevance_score'], reverse=True)[:k]

  def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
-     """Checks if the DB is empty and populates it from a JSONL file if needed."""
+     """
+     Checks if the DB is empty and populates it from a JSONL file if needed.
+     """
      try:
          if db_instance._get_collection().count() > 0:
              logger.info("Vector database already contains data. Skipping population.")
@@ -368,15 +151,9 @@ def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> b
              logger.warning(f"Chunks file at '{chunks_file_path}' is empty or invalid. No data to add.")
              return False

-         # Process in batches to avoid memory issues
-         batch_size = 500
-         for i in range(0, len(chunks_to_add), batch_size):
-             batch = chunks_to_add[i:i + batch_size]
-             db_instance.add_chunks(batch)
-             logger.info(f"Processed batch {i//batch_size + 1}/{(len(chunks_to_add) + batch_size - 1) // batch_size}")
-
+         db_instance.add_chunks(chunks_to_add)
          logger.info("Vector database population attempt complete.")
          return True
      except Exception as e:
          logger.error(f"An error occurred during DB population check: {e}", exc_info=True)
-         return False
+         return False
 
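For reference, a minimal usage sketch of the simplified class after this change. The import path, persist directory, chunks file path, chunk schema, and query string below are illustrative assumptions, not values taken from this repository:

# Hypothetical usage sketch of PolicyVectorDB after this commit.
# Paths and the query are placeholders; the module is assumed importable
# as app.policy_vector_db, matching the file path in this diff.
from app.policy_vector_db import PolicyVectorDB, ensure_db_populated

db = PolicyVectorDB(
    persist_directory="./chroma_db",   # where ChromaDB persists its data
    top_k_default=5,
    relevance_threshold=0.5,
)

# Populate the collection once from a JSONL file of pre-chunked policy text
# (each record needs at least the 'text' and 'metadata' fields used by add_chunks;
# the exact id field lives in a part of the file collapsed in this diff).
ensure_db_populated(db, chunks_file_path="data/policy_chunks.jsonl")

# Queries get the BGE instruction prefix and normalized embeddings internally;
# results scoring below relevance_threshold are dropped.
for hit in db.search("Who approves works contracts above Rs. 5 crore?", top_k=3):
    print(f"{hit['relevance_score']:.3f}", hit['metadata'], hit['text'][:80])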