Kiy-K committed
Commit 36b6bbe · verified · 1 Parent(s): 00f0f74

Update app.py

Files changed (1)
  1. app.py +205 -50
app.py CHANGED
@@ -1,70 +1,225 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
-
- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
-     """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-     """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")

-     messages = [{"role": "system", "content": system_message}]

-     messages.extend(history)

-     messages.append({"role": "user", "content": message})

      response = ""

-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
  chatbot = gr.ChatInterface(
      respond,
      type="messages",
      additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
      ],
  )

- with gr.Blocks() as demo:
      with gr.Sidebar():
          gr.LoginButton()
      chatbot.render()

-
  if __name__ == "__main__":
      demo.launch()
+ # app.py — full version with memory + web search + datasets
+
+ import os
+ import json
+ import threading
  import gradio as gr
+ from huggingface_hub import InferenceClient, snapshot_download
+ from datasets import load_dataset
+ from duckduckgo_search import DDGS

+ client = InferenceClient(
+     provider="cerebras",
+     api_key=os.environ["CEREBRAS_API_KEY"],  # read the key from an environment variable / Space secret; never hard-code it (variable name illustrative)
+ )
+
+ # ---------------- CONFIG ----------------
+ MODEL_ID = "openai/gpt-oss-120b"  # or granite
+ DATA_DIR = "/data" if os.path.isdir("/data") else "./data"
+ os.makedirs(DATA_DIR, exist_ok=True)
+
+ SHORT_TERM_LIMIT = 10
+ SUMMARY_MAX_TOKENS = 150
+ MEMORY_LOCK = threading.Lock()
+
+ # ---------------- dataset loading ----------------
+ # ⚠️ Heavy startup, comment out if running on free HF Space
+ folder = snapshot_download(
+     "HuggingFaceFW/fineweb",
+     repo_type="dataset",
+     local_dir="./fineweb/",
+     allow_patterns="sample/10BT/*",
+ )
+ ds1 = load_dataset("HuggingFaceH4/ultrachat_200k")
+ ds2 = load_dataset("Anthropic/hh-rlhf")
+
+ # ---------------- helpers: memory ----------------
+ def get_user_id(hf_token: gr.OAuthToken | None):
+     if hf_token and getattr(hf_token, "token", None):
+         return "user_" + hf_token.token[:12]
+     return "anon"
+
+ def memory_file_path(user_id: str):
+     return os.path.join(DATA_DIR, f"memory_{user_id}.json")
+
+ def load_memory(user_id: str):
+     p = memory_file_path(user_id)
+     if os.path.exists(p):
+         try:
+             with open(p, "r", encoding="utf-8") as f:
+                 mem = json.load(f)
+             if isinstance(mem, dict) and "short_term" in mem and "long_term" in mem:
+                 return mem
+         except Exception as e:
+             print("load_memory error:", e)
+     return {"short_term": [], "long_term": ""}
+
+ def save_memory(user_id: str, memory: dict):
+     p = memory_file_path(user_id)
+     try:
+         with MEMORY_LOCK:
+             with open(p, "w", encoding="utf-8") as f:
+                 json.dump(memory, f, ensure_ascii=False, indent=2)
+     except Exception as e:
+         print("save_memory error:", e)
+
+ # ---------------- normalize history ----------------
+ def normalize_history(history):
+     out = []
+     if not history: return out
+     for turn in history:
+         if isinstance(turn, dict) and "role" in turn and "content" in turn:
+             out.append({"role": turn["role"], "content": str(turn["content"])})
+         elif isinstance(turn, (list, tuple)) and len(turn) == 2:
+             user_msg, assistant_msg = turn
+             out.append({"role": "user", "content": str(user_msg)})
+             out.append({"role": "assistant", "content": str(assistant_msg)})
+         elif isinstance(turn, str):
+             out.append({"role": "user", "content": turn})
+     return out
+
+ # ---------------- sync completion ----------------
+ def _get_chat_response_sync(client: InferenceClient, messages, max_tokens=SUMMARY_MAX_TOKENS, temperature=0.3, top_p=0.9):
+     try:
+         resp = client.chat_completion(messages, max_tokens=max_tokens, temperature=temperature, top_p=top_p, stream=False)
+     except Exception as e:
+         print("sync chat_completion error:", e)
+         return ""

+     try:
+         choices = resp.get("choices") if isinstance(resp, dict) else getattr(resp, "choices", None)
+         if choices:
+             c0 = choices[0]
+             msg = c0.get("message") if isinstance(c0, dict) else getattr(c0, "message", None)
+             if isinstance(msg, dict):
+                 return msg.get("content", "")
+             return getattr(msg, "content", "") or str(msg or "")
+     except Exception:
+         pass
+     return ""

+ # ---------------- web search ----------------
+ def web_search(query, num_results=3):
+     try:
+         with DDGS() as ddgs:
+             results = list(ddgs.text(query, max_results=num_results))
+         search_context = "🔍 Web Search Results:\n\n"
+         for i, r in enumerate(results, 1):
+             title = r.get("title", "")[:200]
+             body = r.get("body", "")[:200].replace("\n", " ")
+             href = r.get("href", "")
+             search_context += f"{i}. {title}\n{body}...\nSource: {href}\n\n"
+         return search_context
+     except Exception as e:
+         return f"❌ Search error: {str(e)}"
+
+ # ---------------- summarization ----------------
+ def summarize_old_messages(client: InferenceClient, old_messages):
+     text = "\n".join([f"{m['role']}: {m['content']}" for m in old_messages])
+     system = {"role": "system", "content": "You are a summarizer. Summarize <=150 words."}
+     user = {"role": "user", "content": text}
+     return _get_chat_response_sync(client, [system, user])
+
+ # ---------------- memory tools ----------------
+ def show_memory(hf_token: gr.OAuthToken | None = None):
+     user = get_user_id(hf_token)
+     p = memory_file_path(user)
+     if not os.path.exists(p):
+         return "ℹ️ No memory file found for user: " + user
+     with open(p, "r", encoding="utf-8") as f:
+         return f.read()
+
+ def clear_memory(hf_token: gr.OAuthToken | None = None):
+     user = get_user_id(hf_token)
+     p = memory_file_path(user)
+     if os.path.exists(p):
+         os.remove(p)
+         return f"✅ Memory cleared for {user}"
+     return "ℹ️ No memory to clear."
+
+ # ---------------- main chat ----------------
+ def respond(message, history: list, system_message, max_tokens, temperature, top_p,
+             enable_search, enable_persistent_memory, hf_token: gr.OAuthToken = None):
+
+     client = InferenceClient(token=(hf_token.token if hf_token else None), model=MODEL_ID)
+     user_id = get_user_id(hf_token)
+     memory = load_memory(user_id) if enable_persistent_memory else {"short_term": [], "long_term": ""}
+
+     session_history = normalize_history(history)
+     combined = memory.get("short_term", []) + session_history
+
+     if len(combined) > SHORT_TERM_LIMIT:
+         to_summarize = combined[:len(combined) - SHORT_TERM_LIMIT]
+         summary = summarize_old_messages(client, to_summarize)
+         if summary:
+             memory["long_term"] = (memory.get("long_term", "") + "\n" + summary).strip()
+         combined = combined[-SHORT_TERM_LIMIT:]
+
+     combined.append({"role": "user", "content": message})
+     memory["short_term"] = combined
+     if enable_persistent_memory:
+         save_memory(user_id, memory)
+
+     messages = [{"role": "system", "content": system_message}]
+     if memory.get("long_term"):
+         messages.append({"role": "system", "content": "Long-term memory:\n" + memory["long_term"]})
+     messages.extend(memory["short_term"])
+
+     if enable_search and any(k in message.lower() for k in ["search", "google", "tin tức", "news", "what is"]):
+         sr = web_search(message)
+         messages.append({"role": "user", "content": f"{sr}\n\nBased on search results, answer: {message}"})

      response = ""
+     try:
+         for chunk in client.chat_completion(messages, max_tokens=int(max_tokens),
+                                             stream=True, temperature=float(temperature), top_p=float(top_p)):
+             choices = chunk.get("choices") if isinstance(chunk, dict) else getattr(chunk, "choices", None)
+             if not choices: continue
+             c0 = choices[0]
+             delta = c0.get("delta") if isinstance(c0, dict) else getattr(c0, "delta", None)
+             token = None
+             if delta and (delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)):
+                 token = delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)
+             else:
+                 msg = c0.get("message") if isinstance(c0, dict) else getattr(c0, "message", None)
+                 if isinstance(msg, dict):
+                     token = msg.get("content", "")
+                 else:
+                     token = getattr(msg, "content", None) or str(msg or "")
+             if token:
+                 response += token
+                 yield response
+     except Exception as e:
+         yield f"⚠️ Inference error: {e}"
+         return
+
+     memory["short_term"].append({"role": "assistant", "content": response})
+     memory["short_term"] = memory["short_term"][-SHORT_TERM_LIMIT:]
+     if enable_persistent_memory:
+         save_memory(user_id, memory)

+ # ---------------- Gradio UI ----------------
  chatbot = gr.ChatInterface(
      respond,
      type="messages",
      additional_inputs=[
+         gr.Textbox(value="You are a helpful AI assistant.", label="System message"),
+         gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
+         gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
+         gr.Checkbox(value=True, label="Enable Web Search 🔍"),
+         gr.Checkbox(value=True, label="Enable Persistent Memory"),
      ],
  )

+ with gr.Blocks(title="AI Chatbot (full version)") as demo:
+     gr.Markdown("# 🤖 AI Chatbot with Memory + Web Search + Datasets")
      with gr.Sidebar():
          gr.LoginButton()
+         gr.Markdown("### Memory Tools")
+         gr.Button("👀 Show Memory").click(show_memory, inputs=None, outputs=gr.Textbox(label="Memory"))
+         gr.Button("🗑️ Clear Memory").click(clear_memory, inputs=None, outputs=gr.Textbox(label="Status"))
      chatbot.render()

  if __name__ == "__main__":
      demo.launch()
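
The dataset block near the top of the new app.py downloads the full 10BT FineWeb sample with snapshot_download before the UI starts, which the code itself flags as heavy. A lighter alternative, sketched below, is to stream the sample lazily with the datasets library; the "sample-10BT" config name and the streaming approach are assumptions about the intended setup, not part of this commit.

# Sketch: lazy access to the FineWeb sample without a startup download.
# Assumes the "sample-10BT" config of HuggingFaceFW/fineweb; adjust to whatever subset the app really needs.
from datasets import load_dataset

fineweb_stream = load_dataset(
    "HuggingFaceFW/fineweb",
    name="sample-10BT",
    split="train",
    streaming=True,  # records are fetched on demand instead of being copied to ./fineweb/
)

for i, record in enumerate(fineweb_stream):
    print(record["text"][:80])  # FineWeb rows expose their content in the "text" column
    if i >= 2:
        break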
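
For reference, load_memory and save_memory above round-trip a small per-user JSON file named memory_<user_id>.json under DATA_DIR. The sketch below shows its expected shape; the values are illustrative.

# Sketch of the per-user memory document persisted by save_memory (illustrative values).
import json

example_memory = {
    "short_term": [  # at most SHORT_TERM_LIMIT recent turns, in OpenAI-style message format
        {"role": "user", "content": "What is FineWeb?"},
        {"role": "assistant", "content": "A large filtered web-text dataset from Hugging Face."},
    ],
    # older turns are compressed by summarize_old_messages into <=150-word summaries
    "long_term": "The user asked about datasets and prefers short answers.",
}

print(json.dumps(example_memory, ensure_ascii=False, indent=2))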