Bman21 commited on
Commit
6dda441
·
verified ·
1 Parent(s): a03bd33

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ import os
4
+ import faiss
5
+ import numpy as np
6
+ from sentence_transformers import SentenceTransformer
7
+ import pickle
8
+
9
# --- Configuration ---
MODEL_NAME = "openai/gpt-oss-20b"
SECURE_HF_TOKEN = os.environ.get("HF_TOKEN")

# Fail fast at startup: every request needs the token, so a missing secret
# should stop the Space immediately with an actionable message.
if not SECURE_HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set. Add a Secret in Space settings.")

client = InferenceClient(token=SECURE_HF_TOKEN, model=MODEL_NAME)
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# --- Notes folder (TXT training files) ---
notes_folder = "notes"  # <== create this folder in repo and upload TXT files inside
os.makedirs(notes_folder, exist_ok=True)
cache_file = os.path.join(notes_folder, "embeddings_cache.pkl")

# Characters per retrieval chunk; kept as a named constant instead of a
# magic number scattered through the slicing expression.
CHUNK_SIZE = 500

chunks, sources = [], []

# --- Load from cache if exists ---
if os.path.exists(cache_file):
    # NOTE: pickle is only safe here because the cache is produced by this
    # app itself; never load a pickle that could come from an untrusted
    # source. NOTE(review): the cache is not invalidated when TXT files
    # change — delete embeddings_cache.pkl after editing notes.
    with open(cache_file, "rb") as f:
        chunks, sources, embeddings = pickle.load(f)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(np.asarray(embeddings, dtype="float32"))
else:
    # --- Read all TXT files from notes/ ---
    # sorted() makes chunk order deterministic across runs and platforms
    # (os.listdir order is arbitrary), so the cache and retrieval indices
    # are reproducible.
    for file in sorted(os.listdir(notes_folder)):
        if file.endswith(".txt"):
            subject = os.path.splitext(file)[0]
            with open(os.path.join(notes_folder, file), "r", encoding="utf-8", errors="ignore") as f:
                text = f.read()

            # Split into fixed-size character chunks; the last chunk may be
            # shorter than CHUNK_SIZE.
            file_chunks = [text[i:i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
            chunks.extend(file_chunks)
            # One subject label (filename stem) per chunk, kept parallel to
            # `chunks` so retrieval indices work for both lists.
            sources.extend([subject] * len(file_chunks))

    if chunks:
        embeddings = embedder.encode(chunks)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatL2(dim)
        index.add(np.asarray(embeddings, dtype="float32"))

        with open(cache_file, "wb") as f:
            pickle.dump((chunks, sources, embeddings), f)
    else:
        index = None  # no notes uploaded yet; respond() falls back to plain chat
57
# --- Respond function ---
def respond(message, history: list, system_message, max_tokens, temperature, top_p):
    """Stream an answer to *message*, grounded in the most similar note chunks.

    Retrieves up to 3 nearest chunks from the FAISS index (when notes exist),
    prepends their subjects as a "Sources:" line, and yields the growing
    response text as tokens stream back from the inference endpoint.
    """
    context = ""
    source_names = set()

    # Retrieval step: skip entirely when no notes were indexed.
    if index is not None and len(chunks) > 0:
        query_vec = np.asarray(embedder.encode([message])).astype("float32")
        top_k = min(3, len(chunks))
        _, neighbor_ids = index.search(query_vec, k=top_k)
        # FAISS pads missing neighbors with -1; drop those before lookup.
        hits = [i for i in neighbor_ids[0] if i != -1]

        if hits:
            context = "\n".join(chunks[i] for i in hits)
            source_names.update(sources[i] for i in hits)

    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)

    source_text = ""
    if source_names:
        source_text = f"Sources: {', '.join(sorted(source_names))}\n\n"

    prompt_content = f"{source_text}Answer using the following notes if relevant:\n{context}\n\nQuestion: {message}"
    messages.append({"role": "user", "content": prompt_content})

    response = ""
    stream = client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )
    for event in stream:
        # Some stream events carry no choices or an empty delta; treat those
        # as an empty token so every event still yields the accumulated text.
        delta = event.choices[0].delta.content if len(event.choices) else None
        response += delta or ""
        yield response
98
+
99
# --- Gradio Chat Interface ---
# Extra generation controls shown under the chat box; their values are passed
# positionally to respond() after (message, history).
_extra_controls = [
    gr.Textbox(value="Hey, need help?", label="System message"),
    gr.Slider(1, 5000, value=3000, step=1, label="Max new tokens"),
    gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
]

chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=_extra_controls,
)

# --- Launch (mobile-friendly, no sidebar) ---
with gr.Blocks(css=".gradio-container {max-width: 800px; margin:auto;}") as demo:
    gr.Markdown("<h2 style='text-align:center;'>📚 AI Tutor (Trained on Notes)</h2>")
    chatbot.render()

if __name__ == "__main__":
    demo.launch()