"""Agentic retrieval graph: retrieve, grade, rewrite-and-retry, then generate."""
import logging
from typing import List, Literal, Optional, TypedDict

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from qdrant_client.http.models import (
    FieldCondition,
    Filter,
    MatchValue,
)

from clients import LLM, VECTOR_STORE
class RetrievalState(TypedDict):
    """State for the agentic retrieval graph."""

    original_query: str  # user's query as submitted; grading compares against this
    current_query: str  # query actually sent to the vector store (may be a rewrite)
    category: Optional[str]  # optional metadata.category filter for retrieval
    topic: Optional[str]  # optional metadata.topic filter for retrieval
    documents: List[Document]  # raw hits from the vector store
    relevant_documents: List[Document]  # subset of documents graded relevant
    generation: str  # final formatted answer text
    retry_count: int  # number of query rewrites performed so far
    max_retries: int  # upper bound on rewrites before generating anyway
class GradeDocuments(BaseModel):
    """Grade whether a document is relevant to the query."""

    # NOTE: the Field descriptions below are part of the structured-output
    # schema shown to the LLM — changing their wording changes grading behavior.
    is_relevant: Literal["yes", "no"] = Field(
        description="Is the document relevant to the query? 'yes' or 'no'"
    )
    reason: str = Field(description="Brief reason for the relevance decision")
def retrieve_documents(state: RetrievalState) -> RetrievalState:
    """Fetch candidate documents for the current query from the vector store.

    Applies optional Qdrant metadata filters on category and/or topic when
    they are present in the state.
    """
    search_query = state["current_query"]

    # Translate whichever optional metadata constraints are set into
    # Qdrant match conditions.
    conditions = [
        FieldCondition(key=f"metadata.{field}", match=MatchValue(value=value))
        for field, value in (
            ("category", state.get("category")),
            ("topic", state.get("topic")),
        )
        if value
    ]

    hits = VECTOR_STORE.similarity_search(
        search_query,
        k=5,
        filter=Filter(must=conditions) if conditions else None,
    )
    return {**state, "documents": hits}
def grade_documents(state: RetrievalState) -> RetrievalState:
    """Grade retrieved documents for relevance to the original query.

    Each document in ``state["documents"]`` is scored by the LLM against the
    user's *original* query (not the possibly-rewritten ``current_query``);
    only documents graded 'yes' are kept in ``relevant_documents``.

    If a grading call fails, the document is kept anyway (fail-safe: better
    to surface a possibly-irrelevant memory than to drop a relevant one),
    and the failure is logged rather than silently swallowed.
    """
    query = state["original_query"]
    documents = state["documents"]
    if not documents:
        return {**state, "relevant_documents": []}

    # Build the grader and the prompt once, outside the per-document loop.
    grader_llm = LLM.with_structured_output(GradeDocuments)
    grading_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a grader assessing relevance of a retrieved document to a user query.
If the document contains keywords or semantic meaning related to the query, grade it as relevant.
Be lenient - even partial relevance should be marked as 'yes'.
Only mark 'no' if the document is completely unrelated.""",
            ),
            (
                "human",
                """Query: {query}
Document content: {document}
Is this document relevant to the query?""",
            ),
        ]
    )

    logger = logging.getLogger(__name__)
    relevant_docs = []
    for doc in documents:
        try:
            result = grader_llm.invoke(
                grading_prompt.format_messages(
                    query=query,
                    document=doc.page_content[:1000],  # cap content sent to the grader
                )
            )
        except Exception:
            # Fail-safe: include the document, but record the failure so that
            # persistent grader errors are visible instead of silently hidden.
            logger.exception(
                "Document grading failed; including document as a fail-safe"
            )
            relevant_docs.append(doc)
        else:
            if result.is_relevant == "yes":
                relevant_docs.append(doc)
    return {**state, "relevant_documents": relevant_docs}
def rewrite_query(state: RetrievalState) -> RetrievalState:
    """Reformulate the original query and bump the retry counter.

    The rewritten query replaces ``current_query`` so the next retrieval
    attempt searches with it.
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are an expert at reformulating search queries.
Given the original query, generate a better search query that might retrieve more relevant documents.
Focus on:
- Extracting key concepts and entities
- Using synonyms or related terms
- Being more specific or more general as appropriate
Return ONLY the rewritten query, nothing else.""",
            ),
            ("human", "Original query: {query}\n\nRewritten query:"),
        ]
    )

    messages = prompt.format_messages(query=state["original_query"])
    rewritten = LLM.invoke(messages).content.strip()

    return {
        **state,
        "current_query": rewritten,
        "retry_count": state["retry_count"] + 1,
    }
def generate_response(state: RetrievalState) -> RetrievalState:
    """Render the relevant documents into the final ``generation`` string."""
    docs = state["relevant_documents"]
    if not docs:
        return {**state, "generation": "No relevant memories found."}

    # One numbered line per memory, tagged with its category/topic metadata.
    lines = [
        f"{idx}. [{doc.metadata.get('category', 'N/A')}/{doc.metadata.get('topic', 'N/A')}]: {doc.page_content}"
        for idx, doc in enumerate(docs, 1)
    ]
    return {**state, "generation": "\n".join(lines)}
def should_retry(state: RetrievalState) -> Literal["rewrite", "generate"]:
    """Route after grading: retry with a rewrite, or generate the answer."""
    # Relevant material in hand — go straight to answer generation.
    if state["relevant_documents"]:
        return "generate"
    # Nothing relevant: rewrite and retry while budget remains; otherwise
    # proceed to generation, which will report that nothing was found.
    budget_left = state["retry_count"] < state["max_retries"]
    return "rewrite" if budget_left else "generate"
def build_retrieval_graph():
    """Assemble and compile the retrieve → grade → (rewrite | generate) graph."""
    graph = StateGraph(RetrievalState)

    # Register the four processing nodes.
    for node_name, node_fn in (
        ("retrieve", retrieve_documents),
        ("grade", grade_documents),
        ("rewrite", rewrite_query),
        ("generate", generate_response),
    ):
        graph.add_node(node_name, node_fn)

    # Wire the flow: always retrieve then grade; grading decides whether to
    # loop back through a rewrite or finish with generation.
    graph.add_edge(START, "retrieve")
    graph.add_edge("retrieve", "grade")
    graph.add_conditional_edges(
        "grade",
        should_retry,
        {"rewrite": "rewrite", "generate": "generate"},
    )
    graph.add_edge("rewrite", "retrieve")
    graph.add_edge("generate", END)

    return graph.compile()