# hopsworks_chat/app.py
import subprocess

# Install llama-cpp-python at runtime (a common workaround on hosted Spaces
# when the prebuilt wheel is not pinned in requirements.txt).
subprocess.run("pip install llama-cpp-python==0.3.15", shell=True, check=True)
import gradio as gr
import hopsworks
from sentence_transformers import SentenceTransformer
from llama_cpp import Llama
import faiss
import numpy as np
import os
import json
import yaml
from dotenv import load_dotenv
# 1. Load Environment Variables & Validation
load_dotenv()
HOPSWORKS_API_KEY = os.getenv("HOPSWORKS_API_KEY")
if not HOPSWORKS_API_KEY:
    raise ValueError("HOPSWORKS_API_KEY not found in environment variables.")
# Load models configuration
with open("models_config.json", "r") as f:
models_config = json.load(f)
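# The structure of models_config.json is inferred from the keys accessed in
# this file (repositories -> name/repo_id/models -> name/filename/description).
# A minimal hypothetical example (names and files are placeholders, not the
# project's actual config):
#
# {
#   "repositories": [
#     {
#       "name": "Mistral GGUF (example)",
#       "repo_id": "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
#       "models": [
#         {
#           "name": "Mistral 7B Instruct Q4_K_M",
#           "filename": "mistral-7b-instruct-v0.2.Q4_K_M.gguf",
#           "description": "4-bit quantized instruct model (example entry)."
#         }
#       ]
#     }
#   ]
# }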
# Load RAG prompt configuration
with open("prompts/rag_prompt.yml", "r") as f:
prompt_config = yaml.safe_load(f)
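# Likewise, prompts/rag_prompt.yml is assumed to provide at least the fields
# read below: rag.num_retrieved_chunks, rag.context_separator, template, and
# generation.max_tokens/temperature/stop_sequences. A hypothetical sketch:
#
# rag:
#   num_retrieved_chunks: 4
#   context_separator: "\n---\n"
# template: |
#   Answer the question using only the context below.
#
#   Context:
#   {context}
#
#   Question: {question}
#   Answer:
# generation:
#   max_tokens: 512
#   temperature: 0.2
#   stop_sequences: ["Question:"]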
# Global variable to store the current LLM
llm = None
print("Initializing embeddings and connecting to Hopsworks...")
try:
    embeddings = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    project = hopsworks.login(api_key_value=HOPSWORKS_API_KEY)
    fs = project.get_feature_store()
    book_fg = fs.get_feature_group("book_embeddings", version=1)
    df = book_fg.read()
    if df.empty:
        raise ValueError("Feature group 'book_embeddings' is empty.")
    texts = df['text'].tolist()
    # Stored embeddings may come back as lists or numpy arrays; coerce to lists.
    raw_embeddings = [emb if isinstance(emb, list) else emb.tolist() for emb in df['embedding']]
    embedding_vectors = np.array(raw_embeddings, dtype='float32')
    dimension = embedding_vectors.shape[1]
    # L2-normalized vectors in an inner-product index give cosine-similarity search.
    index = faiss.IndexFlatIP(dimension)
    faiss.normalize_L2(embedding_vectors)
    index.add(embedding_vectors)
    print("Embeddings and FAISS index initialized.")
except Exception as e:
    print(f"Critical Error during initialization: {e}")
    index = None
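# Note: `texts`, `embeddings`, and `index` are module-level globals consumed by
# retrieve_context() below. If initialization fails, `index` stays None and
# retrieve_context() returns an error string instead of raising.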
# Function to load a model dynamically
def load_model(repo_name, model_name, progress=gr.Progress()):
    global llm
    try:
        progress(0, desc="Initializing...")
        # Find the repository
        repo = next((r for r in models_config["repositories"] if r["name"] == repo_name), None)
        if not repo:
            return f"Error: Repository '{repo_name}' not found in config."
        # Find the model within the repository
        model = next((m for m in repo["models"] if m["name"] == model_name), None)
        if not model:
            return f"Error: Model '{model_name}' not found in repository."
        print(f"Loading model: {model['name']}...")
        print(f"Repo: {repo['repo_id']}, File: {model['filename']}")
        progress(0.3, desc=f"Downloading/Loading {model['name']}...")
        llm = Llama.from_pretrained(
            repo_id=repo["repo_id"],
            filename=model["filename"],
            n_ctx=2048,
            n_threads=4,
            n_gpu_layers=-1,  # offload all layers to the GPU when one is available
            verbose=False
        )
        progress(1.0, desc="Complete!")
        return f"✅ Model '{model_name}' loaded successfully!"
    except Exception as e:
        llm = None
        return f"❌ Error loading model: {str(e)}"
def retrieve_context(query, k=None):
    if index is None:
        return "Error: Search index not initialized."
    # Use k from prompt config if not specified
    if k is None:
        k = prompt_config["rag"]["num_retrieved_chunks"]
    # Embed and L2-normalize the query so inner product equals cosine similarity.
    query_embedding = embeddings.encode(query).astype('float32').reshape(1, -1)
    faiss.normalize_L2(query_embedding)
    distances, indices = index.search(query_embedding, k)
    retrieved_texts = []
    for i in indices[0]:
        # FAISS pads missing results with -1, so guard the index range.
        if 0 <= i < len(texts):
            retrieved_texts.append(texts[i])
    # Use separator from prompt config
    separator = prompt_config["rag"]["context_separator"]
    print(f"Retrieved {len(retrieved_texts)} context chunks for the query.")
    print("Similarities:", distances)
    return separator.join(retrieved_texts)
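# Hypothetical usage (assuming the index initialized and the example config
# above, with "\n---\n" as separator):
#   retrieve_context("What is a feature store?", k=2)
#   -> "chunk about feature stores\n---\nanother relevant chunk"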
def respond(message, history):
    """
    Generator function for streaming responses.
    gr.ChatInterface passes 'message' and 'history' automatically.
    """
    if llm is None:
        yield "No model is loaded yet. Select a repository and model above, then click 'Load Model'."
        return
    # Retrieve context using config settings
    context = retrieve_context(message)
    # Build prompt from template
    prompt = prompt_config["template"].format(
        context=context,
        question=message
    )
    # Get generation parameters from config
    gen_params = prompt_config["generation"]
    output = llm(
        prompt,
        max_tokens=gen_params["max_tokens"],
        temperature=gen_params["temperature"],
        stop=gen_params["stop_sequences"],
        stream=True
    )
    # gr.ChatInterface expects the cumulative text so far on each yield, not deltas.
    partial_message = ""
    for chunk in output:
        text_chunk = chunk["choices"][0]["text"]
        partial_message += text_chunk
        yield partial_message
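# Note: `history` is accepted to satisfy the gr.ChatInterface signature but is
# not folded into the prompt, so each turn is answered only from the retrieved
# context; multi-turn memory would require adding history to the template.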
with gr.Blocks(title="Hopsworks RAG ChatBot") as demo:
    gr.Markdown("<h1 style='text-align: center; color: #1EB382'>Hopsworks ChatBot</h1>")
    # Model Selection Section
    with gr.Row():
        repo_dropdown = gr.Dropdown(
            choices=[r["name"] for r in models_config["repositories"]],
            label="Select Repository",
            value=models_config["repositories"][0]["name"],
            scale=2
        )
        model_dropdown = gr.Dropdown(
            choices=[m["name"] for m in models_config["repositories"][0]["models"]],
            label="Select Model",
            value=models_config["repositories"][0]["models"][0]["name"],
            scale=2
        )
        load_button = gr.Button("Load Model", variant="primary", scale=1)
    status_box = gr.Textbox(
        label="Status",
        value="⚠️ Please select a repository and model, then click 'Load Model'",
        interactive=False
    )
    # Model info display
    model_info = gr.Markdown("")
    # Chat Interface
    chat_interface = gr.ChatInterface(
        fn=respond,
        chatbot=gr.Chatbot(height=400),
        textbox=gr.Textbox(placeholder="Ask a question about your documents...", container=False, scale=7),
        examples=["What is the main topic of the documents?", "Summarize the key points."],
        cache_examples=False,
    )
    # Function to update model dropdown when repository changes
    def update_model_choices(repo_name):
        repo = next((r for r in models_config["repositories"] if r["name"] == repo_name), None)
        if repo and repo["models"]:
            model_choices = [m["name"] for m in repo["models"]]
            return gr.Dropdown(choices=model_choices, value=model_choices[0])
        return gr.Dropdown(choices=[], value=None)

    # Function to update model info display
    def update_model_info(repo_name, model_name):
        repo = next((r for r in models_config["repositories"] if r["name"] == repo_name), None)
        if not repo:
            return ""
        model = next((m for m in repo["models"] if m["name"] == model_name), None)
        if model:
            return f"**{model['name']}**\n\n{model['description']}\n\nRepository: `{repo['repo_id']}`\n\nFile: `{model['filename']}`"
        return ""
    # Event handlers
    repo_dropdown.change(update_model_choices, inputs=[repo_dropdown], outputs=[model_dropdown])
    repo_dropdown.change(update_model_info, inputs=[repo_dropdown, model_dropdown], outputs=[model_info])
    model_dropdown.change(update_model_info, inputs=[repo_dropdown, model_dropdown], outputs=[model_info])
    load_button.click(load_model, inputs=[repo_dropdown, model_dropdown], outputs=[status_box])

    # Load default model info on startup
    demo.load(
        lambda: update_model_info(
            models_config["repositories"][0]["name"],
            models_config["repositories"][0]["models"][0]["name"]
        ),
        outputs=[model_info]
    )
if __name__ == "__main__":
    demo.launch(share=True)