File size: 10,775 Bytes
dc084a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
# ======================================================================================
# 1. SETUP: Patch SQLite and Import Libraries
# ======================================================================================
# This MUST run before anything imports `sqlite3` (in particular before
# `chromadb` is imported below), so that ChromaDB uses the correct SQLite version.
import sys
import os
# Set before pysqlite3 is imported.
# NOTE(review): presumably selects the SQLite build bundled with pysqlite3 — confirm.
os.environ['PYSQLITE3_BUNDLED'] = '1'
# Import pysqlite3 by name so that it is registered in sys.modules.
__import__('pysqlite3')
# Re-register the pysqlite3 module under the name 'sqlite3' (and remove the
# 'pysqlite3' entry), so later `import sqlite3` statements resolve to it.
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

# Standard and third-party libraries
import json
import pandas as pd
from typing import List, Union

import chromadb

import gradio as gr
from pydantic import BaseModel, ValidationError
from sentence_transformers import SentenceTransformer, CrossEncoder

# LangChain imports
from langchain_openai.chat_models import ChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain.output_parsers import PydanticOutputParser
from langchain_community.embeddings import SentenceTransformerEmbeddings

# ======================================================================================
# 2. CONSTANTS AND CONFIGURATION
# ======================================================================================
# Filesystem path handed to chromadb.PersistentClient in main().
DB_DIR = "./chroma_db"
# Name of the Chroma collection that stores the example documents.
COLLECTION_NAME = "clinical_examples"
# Sentence-transformers model used both to embed the dataset in setup_database()
# and to embed queries through the LangChain vector store wrapper in main().
EMBEDDING_MODEL_NAME = "pritamdeka/S-Biomed-Roberta-snli-multinli-stsb"
# Cross-encoder model used to rerank retrieved candidates.
RERANKER_MODEL_NAME = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
# CSV read with pandas.read_csv when the collection is empty.
DATASET_URL = "https://huggingface.co/datasets/DanFed/patient_encounters1_notes_preprocessed/raw/main/patient_encounters1_notes_preprocessed.csv"


# ======================================================================================
# 3. DATABASE SETUP: One-time data loading and embedding
# ======================================================================================
def setup_database(client: chromadb.Client, batch_size: int = 1000):
    """
    Populate the ChromaDB collection with embedded examples if it is empty.

    The dataset at DATASET_URL is downloaded, each row is turned into a
    "Message / Result" example string, the examples are embedded with the
    model named by EMBEDDING_MODEL_NAME, and everything is written to the
    collection named by COLLECTION_NAME. If the collection already contains
    documents the function returns without doing any work.

    Args:
        client: ChromaDB client that owns (or will own) the collection.
        batch_size: Maximum number of documents sent to Chroma per ``add``
            call. Chroma rejects a single ``add`` that exceeds its maximum
            batch size, so large datasets have to be inserted in chunks.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    collection = client.get_or_create_collection(name=COLLECTION_NAME)

    existing = collection.count()
    if existing > 0:
        print(f"Collection '{COLLECTION_NAME}' already exists with {existing} documents. Skipping setup.")
        return

    print(f"Collection '{COLLECTION_NAME}' is empty. Starting data population...")

    # Load only the two columns that the examples are built from.
    text_columns = ["ENCOUNTER_PROMPT", "COND_MED_PRO_STRUCTURED"]
    df = pd.read_csv(DATASET_URL, usecols=text_columns)
    # A missing cell is parsed as a float NaN, which has no .strip(); such
    # rows cannot form a usable example, so they are left out.
    df = df.dropna(subset=text_columns)

    # Create example strings
    examples = [
        f"Message: \n\n{prompt.strip()}\n\nResult: \n\n{result.strip()}"
        for prompt, result in zip(df["ENCOUNTER_PROMPT"], df["COND_MED_PRO_STRUCTURED"])
    ]
    if not examples:
        print(f"Dataset contains no usable rows. Nothing added to the '{COLLECTION_NAME}' collection.")
        return

    # Generate embeddings
    model = SentenceTransformer(EMBEDDING_MODEL_NAME)
    embeddings = model.encode(
        examples,
        batch_size=32,
        show_progress_bar=True,
        convert_to_numpy=True
    )
    vectors = embeddings.tolist()
    ids = [str(i) for i in range(len(examples))]

    # Add to collection in chunks.
    try:
        for start in range(0, len(examples), batch_size):
            stop = start + batch_size
            collection.add(
                documents=examples[start:stop],
                embeddings=vectors[start:stop],
                ids=ids[start:stop]
            )
    except Exception:
        # A partially filled collection would make the count() check above
        # skip population on the next start, so roll back before re-raising.
        client.delete_collection(name=COLLECTION_NAME)
        raise

    print(f"Successfully added {len(examples)} documents to the '{COLLECTION_NAME}' collection.")


# ======================================================================================
# 4. APPLICATION GLOBALS AND AI COMPONENTS
# ======================================================================================
# Pydantic schema for structured output: the shape the LLM response is parsed
# into by the PydanticOutputParser defined below.
# NOTE(review): documented with `#` comments on purpose — a class docstring may
# be copied into the generated JSON schema and would then change the format
# instructions embedded in the prompt; confirm before adding one.
class ClinicalExtraction(BaseModel):
    conditions: List[str]   # list of strings; presumably extracted condition names
    medications: List[str]  # list of strings; presumably extracted medication names
    procedures: List[str]   # list of strings; presumably extracted procedure names

# Output parser bound to the ClinicalExtraction schema.
parser = PydanticOutputParser(pydantic_object=ClinicalExtraction)

# The instructions are spliced into a prompt template later on, so every
# literal brace is doubled to keep it from being read as a template variable.
format_instructions = parser.get_format_instructions().translate(
    str.maketrans({"{": "{{", "}": "}}"})
)

# Cross-encoder used for reranking; loaded once when the module is imported.
RERANKER = CrossEncoder(RERANKER_MODEL_NAME)

# Lazily initialised AI components:
#   LANGCHAIN_LLM, FINAL_PROMPT, FINAL_CHAIN -> set by initialize_ai_components()
#   VECTOR_STORE                             -> set by main()
LANGCHAIN_LLM = None
FINAL_PROMPT = None
FINAL_CHAIN = None
VECTOR_STORE = None

def initialize_ai_components(api_key: str):
    """
    Build the LLM, the prompt template and the generation chain.

    The three module-level globals LANGCHAIN_LLM, FINAL_PROMPT and FINAL_CHAIN
    are assigned together once every component has been constructed, so a
    failure part-way through does not leave them in a mixed state.

    Args:
        api_key: OpenAI API key. Surrounding whitespace (common when the key
            is pasted into the UI) is ignored.

    Returns:
        An HTML snippet reporting success, shown in the UI status area.

    Raises:
        gr.Error: If no API key (or only whitespace) was supplied.
    """
    global LANGCHAIN_LLM, FINAL_PROMPT, FINAL_CHAIN
    api_key = (api_key or "").strip()
    if not api_key:
        raise gr.Error("OpenAI API Key is required!")

    # LLM
    llm = ChatOpenAI(openai_api_key=api_key, temperature=0.2)

    # Prompt Template. This is an f-string: {format_instructions} is filled in
    # now, while the doubled braces leave {context} and {input} as template
    # variables to be filled in when the chain is invoked.
    prompt = ChatPromptTemplate.from_template(
        f"""You are a clinical information extractor.
Extract EXACTLY this JSON format and nothing else:

{format_instructions}

CONTEXT (examples):

{{context}}

INPUT MESSAGE (clinical note + surrounding metadata):

{{input}}

Result:"""
    )

    # RAG Chain. The chain is invoked with a {"context": ..., "input": ...}
    # dict, which the prompt template consumes directly. Routing it through a
    # mapping of RunnablePassthrough() objects would hand the *entire* dict to
    # both keys, so the prompt would be rendered with the dict repr in place of
    # the context and of the input.
    chain = prompt | llm | StrOutputParser()

    LANGCHAIN_LLM, FINAL_PROMPT, FINAL_CHAIN = llm, prompt, chain
    return "<p style='color:green;'>AI components initialized successfully!</p>"

# ======================================================================================
# 5. RAG PIPELINE
# ======================================================================================
def format_docs(docs):
    """Return the page_content of every document, separated by blank lines."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)

def generate_rag_response(input_text: str) -> Union[dict, str]:
    """
    Run retrieval, reranking, generation and validation for one input.

    Args:
        input_text: The clinical note, optionally prefixed with metadata.

    Returns:
        A dict with the fields of ClinicalExtraction when the model output
        parses and validates, otherwise a human-readable error string (also
        returned when the AI components have not been initialised yet).
    """
    if FINAL_CHAIN is None or VECTOR_STORE is None:
        return "Error: AI components not initialized. Please set your API key."

    # Initial embedding retrieval (top 20)
    retriever = VECTOR_STORE.as_retriever(search_kwargs={"k": 20})
    candidates = retriever.get_relevant_documents(input_text)

    # Cross-encoder rerank -> top 5
    if candidates:
        pairs = [(input_text, d.page_content) for d in candidates]
        scores = RERANKER.predict(pairs)
        # Sort on the score alone: without a key, tied scores make the tuple
        # comparison fall through to the documents themselves, which do not
        # support ordering and raise TypeError. The sort is stable, so ties
        # keep their retrieval order.
        ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)
        top_docs = [doc for _, doc in ranked[:5]]
    else:
        # Nothing retrieved: skip the reranker and generate without examples.
        top_docs = []

    # Build context and invoke chain
    context = format_docs(top_docs)
    raw_output = FINAL_CHAIN.invoke({"context": context, "input": input_text})

    # Parse and validate the output. The LangChain parser reports malformed
    # JSON and schema mismatches as OutputParserException, a ValueError
    # subclass, so ValueError is caught alongside pydantic's ValidationError.
    try:
        parsed = parser.parse(raw_output)
    except (ValidationError, ValueError) as e:
        return f"Schema validation failed: {e}. Raw output was: {raw_output}"
    return parsed.dict()

# ======================================================================================
# 6. GRADIO UI
# ======================================================================================
def create_gradio_ui():
    """Build the Gradio Blocks application and return it (without launching)."""
    with gr.Blocks(title="Clinical Information Extractor") as demo:
        gr.Markdown("# Clinical Information Extractor with RAG and Reranking")

        # --- API key section -------------------------------------------------
        with gr.Accordion("API Key Configuration", open=True):
            key_box = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")
            key_btn = gr.Button("Set API Key")
            key_status = gr.Markdown("")
        key_btn.click(initialize_ai_components, inputs=[key_box], outputs=[key_status])

        gr.Markdown("---")
        gr.Markdown("## Enter Clinical Note and Metadata")

        # --- Input fields ----------------------------------------------------
        with gr.Row():
            age_group_input = gr.Textbox(label="Age Group", placeholder="e.g., middle adulthood")
            visit_type_input = gr.Textbox(label="Visit Type", placeholder="e.g., ambulatory")
        description_input = gr.Textbox(label="Description", placeholder="e.g., encounter for check up (procedure)")
        note_input = gr.Textbox(label="Clinical Note", placeholder="Type the clinical note here...", lines=5)

        # --- Output ----------------------------------------------------------
        chatbot = gr.Chatbot(label="Extraction History", height=400)
        send_btn = gr.Button("➡️ Extract Information")

        def chat_interface(age, visit, desc, note, history):
            """Run one extraction and append the exchange to the chat history."""
            history = history or []

            # Only metadata fields that were actually filled in are included.
            labelled_fields = (
                ("Age group", age),
                ("Visit type", visit),
                ("Description", desc),
            )
            metadata_str = " | ".join(
                f"{label}: {value}" for label, value in labelled_fields if value
            )

            if metadata_str:
                full_input = f"{metadata_str}\n\nClinical Note:\n{note}"
            else:
                full_input = note
            user_display = f"**Metadata**: {metadata_str}\n\n**Note**: {note}"

            # A dict is a successful extraction; anything else is an error message.
            response = generate_rag_response(full_input)
            if isinstance(response, dict):
                bot_response = f"```json\n{json.dumps(response, indent=2)}\n```"
            else:
                bot_response = str(response)

            history.append((user_display, bot_response))
            # Second value clears the clinical-note textbox.
            return history, ""

        # The button click and pressing Enter in the note box do the same thing.
        for register in (send_btn.click, note_input.submit):
            register(
                fn=chat_interface,
                inputs=[age_group_input, visit_type_input, description_input, note_input, chatbot],
                outputs=[chatbot, note_input]
            )
    return demo

# ======================================================================================
# 7. MAIN EXECUTION
# ======================================================================================
def main():
    """
    Entry point: prepare the vector database, wire up the global vector
    store, then build and launch the Gradio UI.
    """
    global VECTOR_STORE

    # Persistent Chroma client, populated on first run only.
    chroma_client = chromadb.PersistentClient(path=DB_DIR)
    setup_database(chroma_client)

    # LangChain wrapper over the same collection, embedding queries with the
    # same model name that was used to embed the stored documents.
    VECTOR_STORE = Chroma(
        client=chroma_client,
        collection_name=COLLECTION_NAME,
        embedding_function=SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL_NAME),
    )
    doc_count = VECTOR_STORE._collection.count()
    print(f"Vector store initialized with {doc_count} documents.")

    # Build the UI and serve it on all network interfaces.
    app = create_gradio_ui()
    print("Launching Clinical IE Demo...")
    app.launch(server_name="0.0.0.0")

# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()