Kiy-K committed
Commit a4cec46 · verified · 1 Parent(s): 3c28fa4

Update app.py

Files changed (1)
  app.py +164 -126
app.py CHANGED
@@ -1,16 +1,14 @@
-# app.py — full version with memory + web search + datasets
-
 import os
 import json
 import threading
 import gradio as gr
-from huggingface_hub import InferenceClient, snapshot_download
 from datasets import load_dataset
 from duckduckgo_search import DDGS
 
-
 # ---------------- CONFIG ----------------
-MODEL_ID = "openai/gpt-oss-120b"  # or granite
 DATA_DIR = "/data" if os.path.isdir("/data") else "./data"
 os.makedirs(DATA_DIR, exist_ok=True)
 
@@ -18,27 +16,71 @@ SHORT_TERM_LIMIT = 10
 SUMMARY_MAX_TOKENS = 150
 MEMORY_LOCK = threading.Lock()
 
-# ---------------- dataset loading ----------------
-# ⚠️ Heavy startup, comment out if running on free HF Space
-folder = snapshot_download(
-    "HuggingFaceFW/fineweb",
-    repo_type="dataset",
-    local_dir="./fineweb/",
-    allow_patterns="sample/10BT/*",
-)
-ds1 = load_dataset("HuggingFaceH4/ultrachat_200k")
-ds2 = load_dataset("Anthropic/hh-rlhf")
 
-# ---------------- helpers: memory ----------------
-def get_user_id(hf_token: gr.OAuthToken | None):
     if hf_token and getattr(hf_token, "token", None):
         return "user_" + hf_token.token[:12]
     return "anon"
 
-def memory_file_path(user_id: str):
     return os.path.join(DATA_DIR, f"memory_{user_id}.json")
 
-def load_memory(user_id: str):
     p = memory_file_path(user_id)
     if os.path.exists(p):
         try:
@@ -50,7 +92,7 @@ def load_memory(user_id: str):
             print("load_memory error:", e)
     return {"short_term": [], "long_term": ""}
 
-def save_memory(user_id: str, memory: dict):
     p = memory_file_path(user_id)
     try:
         with MEMORY_LOCK:
@@ -59,10 +101,10 @@ def save_memory(user_id: str, memory: dict):
     except Exception as e:
         print("save_memory error:", e)
 
-# ---------------- normalize history ----------------
 def normalize_history(history):
     out = []
-    if not history: return out
     for turn in history:
         if isinstance(turn, dict) and "role" in turn and "content" in turn:
             out.append({"role": turn["role"], "content": str(turn["content"])})
@@ -70,35 +112,14 @@ def normalize_history(history):
             user_msg, assistant_msg = turn
             out.append({"role": "user", "content": str(user_msg)})
             out.append({"role": "assistant", "content": str(assistant_msg)})
-        elif isinstance(turn, str):
-            out.append({"role": "user", "content": turn})
     return out
 
-# ---------------- sync completion ----------------
-def _get_chat_response_sync(client: InferenceClient, messages, max_tokens=SUMMARY_MAX_TOKENS, temperature=0.3, top_p=0.9):
-    try:
-        resp = client.chat_completion(messages, max_tokens=max_tokens, temperature=temperature, top_p=top_p, stream=False)
-    except Exception as e:
-        print("sync chat_completion error:", e)
-        return ""
-
-    try:
-        choices = resp.get("choices") if isinstance(resp, dict) else getattr(resp, "choices", None)
-        if choices:
-            c0 = choices[0]
-            msg = c0.get("message") if isinstance(c0, dict) else getattr(c0, "message", None)
-            if isinstance(msg, dict):
-                return msg.get("content", "")
-            return getattr(msg, "content", "") or str(msg or "")
-    except Exception:
-        pass
-    return ""
-
-# ---------------- web search ----------------
 def web_search(query, num_results=3):
     try:
         with DDGS() as ddgs:
             results = list(ddgs.text(query, max_results=num_results))
         search_context = "🔍 Web Search Results:\n\n"
         for i, r in enumerate(results, 1):
             title = r.get("title", "")[:200]
@@ -107,114 +128,131 @@ def web_search(query, num_results=3):
             search_context += f"{i}. {title}\n{body}...\nSource: {href}\n\n"
         return search_context
     except Exception as e:
-        return f"❌ Search error: {str(e)}"
 
-# ---------------- summarization ----------------
-def summarize_old_messages(client: InferenceClient, old_messages):
-    text = "\n".join([f"{m['role']}: {m['content']}" for m in old_messages])
-    system = {"role": "system", "content": "You are a summarizer. Summarize <=150 words."}
-    user = {"role": "user", "content": text}
-    return _get_chat_response_sync(client, [system, user])
-
-# ---------------- memory tools ----------------
-def show_memory(hf_token: gr.OAuthToken | None = None):
     user = get_user_id(hf_token)
     p = memory_file_path(user)
     if not os.path.exists(p):
-        return "ℹ️ No memory file found for user: " + user
     with open(p, "r", encoding="utf-8") as f:
         return f.read()
 
-def clear_memory(hf_token: gr.OAuthToken | None = None):
     user = get_user_id(hf_token)
     p = memory_file_path(user)
     if os.path.exists(p):
         os.remove(p)
-        return f"✅ Memory cleared for {user}"
-    return "ℹ️ No memory to clear."
-
-# ---------------- main chat ----------------
-def respond(message, history: list, system_message, max_tokens, temperature, top_p,
-            enable_search, enable_persistent_memory, hf_token: gr.OAuthToken = None):
-
-    client = InferenceClient(token=(hf_token.token if hf_token else None), model=MODEL_ID)
-    user_id = get_user_id(hf_token)
-    memory = load_memory(user_id) if enable_persistent_memory else {"short_term": [], "long_term": ""}
-
-    session_history = normalize_history(history)
-    combined = memory.get("short_term", []) + session_history
-
-    if len(combined) > SHORT_TERM_LIMIT:
-        to_summarize = combined[:len(combined) - SHORT_TERM_LIMIT]
-        summary = summarize_old_messages(client, to_summarize)
-        if summary:
-            memory["long_term"] = (memory.get("long_term", "") + "\n" + summary).strip()
-        combined = combined[-SHORT_TERM_LIMIT:]
-
-    combined.append({"role": "user", "content": message})
-    memory["short_term"] = combined
-    if enable_persistent_memory:
-        save_memory(user_id, memory)
-
-    messages = [{"role": "system", "content": system_message}]
-    if memory.get("long_term"):
-        messages.append({"role": "system", "content": "Long-term memory:\n" + memory["long_term"]})
-    messages.extend(memory["short_term"])
-
-    if enable_search and any(k in message.lower() for k in ["search", "google", "tin tức", "news", "what is"]):
-        sr = web_search(message)
-        messages.append({"role": "user", "content": f"{sr}\n\nBased on search results, answer: {message}"})
-
-    response = ""
     try:
-        for chunk in client.chat_completion(messages, max_tokens=int(max_tokens),
-                                            stream=True, temperature=float(temperature), top_p=float(top_p)):
             choices = chunk.get("choices") if isinstance(chunk, dict) else getattr(chunk, "choices", None)
-            if not choices: continue
-            c0 = choices[0]
-            delta = c0.get("delta") if isinstance(c0, dict) else getattr(c0, "delta", None)
-            token = None
-            if delta and (delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)):
-                token = delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)
-            else:
-                msg = c0.get("message") if isinstance(c0, dict) else getattr(c0, "message", None)
-                if isinstance(msg, dict):
-                    token = msg.get("content", "")
-                else:
-                    token = getattr(msg, "content", None) or str(msg or "")
-            if token:
-                response += token
-                yield response
     except Exception as e:
-        yield f"⚠️ Inference error: {e}"
-        return
-
-    memory["short_term"].append({"role": "assistant", "content": response})
-    memory["short_term"] = memory["short_term"][-SHORT_TERM_LIMIT:]
-    if enable_persistent_memory:
-        save_memory(user_id, memory)
 
-# ---------------- Gradio UI ----------------
 chatbot = gr.ChatInterface(
     respond,
     type="messages",
     additional_inputs=[
-        gr.Textbox(value="You are a helpful AI assistant.", label="System message"),
-        gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
         gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
-        gr.Checkbox(value=True, label="Enable Web Search 🔍"),
-        gr.Checkbox(value=True, label="Enable Persistent Memory"),
     ],
 )
 
-with gr.Blocks(title="AI Chatbot (full version)") as demo:
-    gr.Markdown("# 🤖 AI Chatbot with Memory + Web Search + Datasets")
     with gr.Sidebar():
         gr.LoginButton()
         gr.Markdown("### Memory Tools")
-        gr.Button("👀 Show Memory").click(show_memory, inputs=None, outputs=gr.Textbox(label="Memory"))
-        gr.Button("🗑️ Clear Memory").click(clear_memory, inputs=None, outputs=gr.Textbox(label="Status"))
     chatbot.render()
 
 if __name__ == "__main__":

+# app.py — Fixed version with streaming + memory + web search
 import os
 import json
 import threading
 import gradio as gr
+from huggingface_hub import InferenceClient
 from datasets import load_dataset
 from duckduckgo_search import DDGS
 
 # ---------------- CONFIG ----------------
+MODEL_ID = "openai/gpt-oss-120b"
 DATA_DIR = "/data" if os.path.isdir("/data") else "./data"
 os.makedirs(DATA_DIR, exist_ok=True)
 
@@ -18,27 +16,71 @@ SHORT_TERM_LIMIT = 10
 SUMMARY_MAX_TOKENS = 150
 MEMORY_LOCK = threading.Lock()
 
+# ---------------- SIMPLE STREAMING DATASET ----------------
+# Only load what we actually use to avoid errors
+print("Loading FineWeb in streaming mode...")
+try:
+    fineweb_stream = load_dataset(
+        "HuggingFaceFW/fineweb",
+        split="train",
+        streaming=True
+    )
+    print("✅ FineWeb streaming loaded")
+except Exception as e:
+    print(f"FineWeb loading failed: {e}")
+    fineweb_stream = None
+
+# Keep other datasets as before for stability
+try:
+    ds1 = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:5000]")  # small sample; this repo ships "train_sft", not "train"
+    ds2 = load_dataset("Anthropic/hh-rlhf", split="train[:5000]")  # small sample
+    print("✅ Other datasets loaded")
+except Exception as e:
+    print(f"Dataset loading error: {e}")
+    ds1, ds2 = None, None
+
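Since everything below assumes each streamed record carries a `text` field, a cheap smoke test against the stream can catch schema surprises early; a minimal sketch, not part of the commit (`islice` is stdlib):

from itertools import islice

if fineweb_stream is not None:
    # Pull two records without downloading the dataset; each streamed
    # row is a plain dict, and the search below relies on its "text" key.
    for sample in islice(fineweb_stream, 2):
        print(sorted(sample.keys()), len(sample.get("text", "")))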
+# ---------------- SIMPLE FINEWEB SEARCH ----------------
+def search_fineweb(query, max_search=1000):
+    """Simple FineWeb search - safe version"""
+    if not fineweb_stream:
+        return "FineWeb not available"
+
+    try:
+        query_lower = query.lower()
+        found_content = []
+        count = 0
+
+        for sample in fineweb_stream:
+            if count >= max_search:
+                break
+
+            text = sample.get('text', '')
+            if len(text) > 50 and query_lower in text.lower():
+                content = text[:300] + "..." if len(text) > 300 else text
+                found_content.append(content)
+                if len(found_content) >= 3:  # Max 3 results
+                    break
+
+            count += 1
+
+        if found_content:
+            return "📚 FineWeb Results:\n\n" + "\n\n---\n\n".join(found_content)
+        else:
+            return "No relevant FineWeb content found"
+
+    except Exception as e:
+        return f"FineWeb search error: {str(e)}"
 
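Worth noting: each call iterates `fineweb_stream` from the top, so every triggered search re-fetches up to `max_search` records over the network. One alternative (a sketch, not what this commit does) is to materialize a small cache once at startup and search it in memory:

from itertools import islice

# Hypothetical startup cache: trades a little RAM for fast repeat searches.
FINEWEB_CACHE = list(islice(fineweb_stream, 1000)) if fineweb_stream else []

def search_fineweb_cached(query, max_results=3):
    q = query.lower()
    hits = [s.get("text", "")[:300] for s in FINEWEB_CACHE
            if q in s.get("text", "").lower()]
    return hits[:max_results]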
+# ---------------- MEMORY FUNCTIONS (SAME AS BEFORE) ----------------
+def get_user_id(hf_token):
     if hf_token and getattr(hf_token, "token", None):
         return "user_" + hf_token.token[:12]
     return "anon"
 
+def memory_file_path(user_id):
     return os.path.join(DATA_DIR, f"memory_{user_id}.json")
 
+def load_memory(user_id):
     p = memory_file_path(user_id)
     if os.path.exists(p):
         try:
@@ -50,7 +92,7 @@ def load_memory(user_id):
         print("load_memory error:", e)
     return {"short_term": [], "long_term": ""}
 
+def save_memory(user_id, memory):
     p = memory_file_path(user_id)
     try:
         with MEMORY_LOCK:
@@ -59,10 +101,10 @@ def save_memory(user_id: str, memory: dict):
     except Exception as e:
         print("save_memory error:", e)
 
 def normalize_history(history):
     out = []
+    if not history:
+        return out
     for turn in history:
         if isinstance(turn, dict) and "role" in turn and "content" in turn:
             out.append({"role": turn["role"], "content": str(turn["content"])})
@@ -70,35 +112,14 @@ def normalize_history(history):
             user_msg, assistant_msg = turn
             out.append({"role": "user", "content": str(user_msg)})
             out.append({"role": "assistant", "content": str(assistant_msg)})
     return out
 
+# ---------------- WEB SEARCH (SAME AS BEFORE) ----------------
 def web_search(query, num_results=3):
     try:
         with DDGS() as ddgs:
             results = list(ddgs.text(query, max_results=num_results))
+
         search_context = "🔍 Web Search Results:\n\n"
         for i, r in enumerate(results, 1):
             title = r.get("title", "")[:200]
@@ -107,114 +128,131 @@ def web_search(query, num_results=3):
             search_context += f"{i}. {title}\n{body}...\nSource: {href}\n\n"
         return search_context
     except Exception as e:
+        return f"Search error: {str(e)}"
 
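For reference, `ddgs.text()` yields one dict per hit, and the `r.get(...)` calls above read its `title`, `body`, and `href` keys; a standalone check of that shape:

from duckduckgo_search import DDGS

with DDGS() as ddgs:
    for r in ddgs.text("gradio chatbot", max_results=1):
        # Same keys web_search() slices and truncates.
        print(r.get("title"), r.get("href"))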
+# ---------------- MEMORY TOOLS ----------------
+def show_memory(hf_token=None):
     user = get_user_id(hf_token)
     p = memory_file_path(user)
     if not os.path.exists(p):
+        return f"No memory found for {user}"
     with open(p, "r", encoding="utf-8") as f:
         return f.read()
 
+def clear_memory(hf_token=None):
     user = get_user_id(hf_token)
     p = memory_file_path(user)
     if os.path.exists(p):
         os.remove(p)
+        return f"Memory cleared for {user}"
+    return "No memory to clear"
+
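`show_memory` returns the raw file contents; given `load_memory`'s defaults, the stored JSON has this shape (values here are illustrative only):

# Illustrative contents of memory_<user>.json after one exchange:
example_memory = {
    "short_term": [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "Hello! How can I help?"},
    ],
    "long_term": "",
}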
+# ---------------- MAIN CHAT FUNCTION ----------------
+def respond(message, history, system_message, max_tokens, temperature, top_p,
+            enable_web_search, enable_fineweb_search, enable_memory, hf_token=None):
+
     try:
+        client = InferenceClient(token=(hf_token.token if hf_token else None), model=MODEL_ID)
+        user_id = get_user_id(hf_token)
+
+        # Memory handling
+        memory = load_memory(user_id) if enable_memory else {"short_term": [], "long_term": ""}
+        session_history = normalize_history(history)
+        combined = memory.get("short_term", []) + session_history
+        combined.append({"role": "user", "content": message})
+
+        # Keep memory manageable
+        if len(combined) > SHORT_TERM_LIMIT:
+            combined = combined[-SHORT_TERM_LIMIT:]
+
+        memory["short_term"] = combined
+        if enable_memory:
+            save_memory(user_id, memory)
+
+        # Build messages
+        messages = [{"role": "system", "content": system_message}]
+
+        # Add memory context
+        if memory.get("long_term"):
+            messages.append({"role": "system", "content": f"Memory: {memory['long_term']}"})
+
+        # Add search results if needed
+        search_keywords = ["search", "find", "what is", "tell me about", "news", "latest"]
+        should_search = any(keyword in message.lower() for keyword in search_keywords)
+
+        context_parts = []
+
+        if enable_web_search and should_search:
+            web_results = web_search(message)
+            context_parts.append(web_results)
+
+        if enable_fineweb_search and should_search:
+            fineweb_results = search_fineweb(message)
+            if "not available" not in fineweb_results and "No relevant" not in fineweb_results:
+                context_parts.append(fineweb_results)
+
+        if context_parts:
+            search_context = "\n\n".join(context_parts)
+            messages.append({"role": "system", "content": f"Context:\n{search_context}"})
+
+        messages.extend(memory["short_term"])
+
+        # Generate response
+        response = ""
+        for chunk in client.chat_completion(
+            messages,
+            max_tokens=int(max_tokens),
+            stream=True,
+            temperature=float(temperature),
+            top_p=float(top_p)
+        ):
             choices = chunk.get("choices") if isinstance(chunk, dict) else getattr(chunk, "choices", None)
+            if choices:
+                delta = choices[0].get("delta") if isinstance(choices[0], dict) else getattr(choices[0], "delta", None)
+                if delta:
+                    token = delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)
+                    if token:
+                        response += token
+                        yield response
+
+        # Save response to memory
+        memory["short_term"].append({"role": "assistant", "content": response})
+        memory["short_term"] = memory["short_term"][-SHORT_TERM_LIMIT:]
+        if enable_memory:
+            save_memory(user_id, memory)
+
     except Exception as e:
+        yield f"Error: {str(e)}"
 
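One behavioral change to be aware of: the old `respond` annotated its last parameter as `gr.OAuthToken`, which is what tells Gradio to inject the logged-in user's token; with the bare `hf_token=None` above, the function always sees `None` and falls back to anonymous inference. If injection is wanted, the signature would need the annotation back (a sketch restoring the old annotation):

def respond(message, history, system_message, max_tokens, temperature, top_p,
            enable_web_search, enable_fineweb_search, enable_memory,
            hf_token: gr.OAuthToken | None = None):
    # Gradio populates hf_token only when the parameter is annotated
    # with gr.OAuthToken and a LoginButton is present.
    ...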
+# ---------------- GRADIO UI ----------------
 chatbot = gr.ChatInterface(
     respond,
     type="messages",
     additional_inputs=[
+        gr.Textbox(value="You are a helpful AI assistant with access to web search and knowledge datasets.", label="System message"),
+        gr.Slider(1, 2048, value=512, step=1, label="Max tokens"),
         gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
+        gr.Checkbox(value=True, label="🌐 Web Search"),
+        gr.Checkbox(value=True, label="📚 FineWeb Search"),
+        gr.Checkbox(value=True, label="🧠 Memory"),
     ],
 )
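`gr.ChatInterface` passes `additional_inputs` positionally after `(message, history)`, so the widget order above must mirror `respond`'s extra parameters, which it does:

# additional_inputs -> respond parameters, in order:
#   Textbox  -> system_message
#   Slider   -> max_tokens
#   Slider   -> temperature
#   Slider   -> top_p
#   Checkbox -> enable_web_search
#   Checkbox -> enable_fineweb_search
#   Checkbox -> enable_memory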
 
+with gr.Blocks(title="AI Chatbot - Fixed Version") as demo:
+    gr.Markdown("# 🤖 AI Chatbot with Streaming FineWeb + Memory + Web Search")
+
     with gr.Sidebar():
         gr.LoginButton()
         gr.Markdown("### Memory Tools")
+
+        show_btn = gr.Button("👀 Show Memory")
+        clear_btn = gr.Button("🗑️ Clear Memory")
+        memory_display = gr.Textbox(label="Memory Status", lines=5)
+
+        show_btn.click(show_memory, inputs=None, outputs=memory_display)
+        clear_btn.click(clear_memory, inputs=None, outputs=memory_display)
+
     chatbot.render()
 
 if __name__ == "__main__":