Anirudh Esthuri committed
Commit e91e2b4 · 1 Parent(s): bd7679d

Copy all files from Playground - app, gateway_client, llm, model_config, requirements, styles, assets, and config files

.gitignore ADDED
@@ -0,0 +1,53 @@
+ # macOS
+ .DS_Store
+ .AppleDouble
+ .LSOverride
+
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # Virtual environments
+ venv/
+ env/
+ ENV/
+ .venv
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # Environment variables
+ .env
+ .env.local
+
+ # Logs
+ *.log
+
+ # Cache
+ .cache/
+ .pytest_cache/
+ .mypy_cache/
+
.streamlit/config.toml ADDED
@@ -0,0 +1,4 @@
+ [server]
+ headless = true
+ enableCORS = false
+ enableXsrfProtection = false
Dockerfile ADDED
@@ -0,0 +1,14 @@
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ RUN apt-get update && apt-get install -y git
+
+ COPY . .
+
+ RUN pip install --upgrade pip
+ RUN pip install -r requirements.txt
+
+ # HuggingFace sets $PORT, don't override it.
+
+ CMD ["bash", "-c", "echo Using PORT=$PORT && streamlit run app.py --server.address 0.0.0.0 --server.port $PORT"]
README.md CHANGED
@@ -1,12 +1,17 @@
---
title: MemMachine Playground
- emoji: 📊
- colorFrom: yellow
- colorTo: blue
+ emoji: 🧠
sdk: docker
+ app_port: 7860
pinned: false
- license: apache-2.0
- short_description: MemMachine-Playground
---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # MemMachine Frontend Playground
+
+ This is a Streamlit-based UI for interacting with a remote MemMachine backend.
+
+ - Frontend: Streamlit (runs in this Space)
+ - Backend: MemMachine server running on EC2
+ - Memory + vector search: Neo4j + Postgres
+ - All requests route to your backend via `gateway_client.py`
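For orientation, here is a minimal sketch of the flow the README describes, run outside the Streamlit UI. It assumes MEMORY_SERVER_URL and MODEL_API_KEY are already set in the environment and that the MemMachine backend is reachable; the persona and prompt are illustrative, not part of this commit.

    from gateway_client import ingest_and_rewrite
    from llm import chat, set_model

    set_model("gpt-4.1-mini")  # default OpenAI model from model_config.py

    # 1) Route the raw message through the backend: ingest it and pull back profile/context
    rewritten = ingest_and_rewrite(
        user_id="Charlie", query="Plan a quick dinner for tonight", model_type="openai"
    )

    # 2) Send the context-enriched message to the selected model
    text, latency, tokens, tps = chat([{"role": "user", "content": rewritten}], "Charlie")
    print(text, f"({latency:.1f}s, {tokens} tokens, {tps:.1f} tok/s)")
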
app.py ADDED
@@ -0,0 +1,193 @@
+ import os
+ from typing import cast
+
+ import streamlit as st
+ from gateway_client import delete_profile, ingest_and_rewrite
+ from llm import chat, set_model
+ from model_config import MODEL_CHOICES, MODEL_TO_PROVIDER
+
+
+ def rewrite_message(
+     msg: str, persona_name: str, show_rationale: bool, skip_rewrite: bool
+ ) -> str:
+     if skip_rewrite:
+         rewritten_msg = msg
+         if show_rationale:
+             rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: No personalization applied.'. Begin your answer on the next line."
+     else:
+         try:
+             rewritten_msg = ingest_and_rewrite(
+                 user_id=persona_name, query=msg, model_type=provider
+             )
+             if show_rationale:
+                 rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: ' followed by 1 sentence, also in italics, explaining how the persona traits influenced this response. Begin your answer on the next line."
+
+         except Exception as e:
+             # If the backend is unavailable, use the original message without rewriting
+             st.warning(f"Backend memory server unavailable. Using message without personalization: {e}")
+             rewritten_msg = msg
+             if show_rationale:
+                 rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: No personalization applied (backend unavailable).'. Begin your answer on the next line."
+     return rewritten_msg
+
+
+ # ──────────────────────────────────────────────────────────────
+ # Page setup & CSS
+ # ──────────────────────────────────────────────────────────────
+ st.set_page_config(page_title="MemMachine Chatbot", layout="wide")
+ with open("./styles.css") as f:
+     st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
+
+
+ # ──────────────────────────────────────────────────────────────
+ # Sidebar
+ # ──────────────────────────────────────────────────────────────
+ with st.sidebar:
+     st.image("./assets/memmachine_logo.png", use_container_width=True)
+
+     st.markdown("#### Choose Model")
+
+     model_id = st.selectbox(
+         "Choose Model", MODEL_CHOICES, index=0, label_visibility="collapsed"
+     )
+     provider = MODEL_TO_PROVIDER[model_id]
+     set_model(model_id)
+
+     st.markdown("#### Choose user persona")
+     selected_persona = st.selectbox(
+         "Choose user persona",
+         ["Charlie", "Jing", "Charles", "Control"],
+         label_visibility="collapsed",
+     )
+     custom_persona = st.text_input("Or enter your name", "")
+     persona_name = (
+         custom_persona.strip() if custom_persona.strip() else selected_persona
+     )
+
+     skip_rewrite = st.checkbox("Skip Rewrite")
+     compare_personas = st.checkbox("Compare with Control persona")
+     show_rationale = st.checkbox("Show Persona Rationale")
+
+     st.divider()
+     if st.button("Clear chat", use_container_width=True):
+         st.session_state.history = []
+         st.rerun()
+     if st.button("Delete Profile", use_container_width=True):
+         success = delete_profile(persona_name)
+         st.session_state.history = []
+         if success:
+             st.success(f"Profile for '{persona_name}' deleted.")
+         else:
+             st.error(f"Failed to delete profile for '{persona_name}'.")
+     st.divider()
+
+ # ──────────────────────────────────────────────────────────────
+ # Session state
+ # ──────────────────────────────────────────────────────────────
+ if "history" not in st.session_state:
+     st.session_state.history = cast(list[dict], [])
+
+
+ # ──────────────────────────────────────────────────────────────
+ # Enforce alternating roles
+ # ──────────────────────────────────────────────────────────────
+ def clean_history(history: list[dict], persona: str) -> list[dict]:
+     out = []
+     for turn in history:
+         if turn.get("role") == "user":
+             out.append({"role": "user", "content": turn["content"]})
+         elif turn.get("role") == "assistant" and turn.get("persona") == persona:
+             out.append({"role": "assistant", "content": turn["content"]})
+     cleaned = []
+     last_role = None
+     for msg in out:
+         if msg["role"] != last_role:
+             cleaned.append(msg)
+             last_role = msg["role"]
+     return cleaned
+
+
+ def append_user_turn(msgs: list[dict], new_user_msg: str) -> list[dict]:
+     if msgs and msgs[-1]["role"] == "user":
+         msgs[-1] = {"role": "user", "content": new_user_msg}
+     else:
+         msgs.append({"role": "user", "content": new_user_msg})
+     return msgs
+
+
+ # ──────────────────────────────────────────────────────────────
+ # Title
+ # ──────────────────────────────────────────────────────────────
+ st.title("MemMachine Chatbot")
+
+ # ──────────────────────────────────────────────────────────────
+ # Chat logic
+ # ──────────────────────────────────────────────────────────────
+ msg = st.chat_input("Type your message…")
+ if msg:
+     st.session_state.history.append({"role": "user", "content": msg})
+     # rewritten_msg = "Use the persona profile to personalize your answer only when applicable.\n"
+     if compare_personas:
+         all_answers = {}
+         rewritten_msg = rewrite_message(msg, persona_name, show_rationale, False)
+         msgs = clean_history(st.session_state.history, persona_name)
+         msgs = append_user_turn(msgs, rewritten_msg)
+         txt, lat, tok, tps = chat(msgs, persona_name)
+         all_answers[persona_name] = txt
+
+         rewritten_msg_control = rewrite_message(msg, "Control", show_rationale, True)
+         msgs_control = clean_history(st.session_state.history, "Control")
+         msgs_control = append_user_turn(msgs_control, rewritten_msg_control)
+         txt_control, lat, tok, tps = chat(msgs_control, "Arnold")
+         all_answers["Control"] = txt_control
+
+         st.session_state.history.append(
+             {"role": "assistant_all", "axis": "role", "content": all_answers}
+         )
+     else:
+         rewritten_msg = rewrite_message(msg, persona_name, show_rationale, skip_rewrite)
+         msgs = clean_history(st.session_state.history, persona_name)
+         msgs = append_user_turn(msgs, rewritten_msg)
+         txt, lat, tok, tps = chat(
+             msgs, "Arnold" if persona_name == "Control" else persona_name
+         )
+         st.session_state.history.append(
+             {"role": "assistant", "persona": persona_name, "content": txt}
+         )
+
+ # ──────────────────────────────────────────────────────────────
+ # Chat history display
+ # ──────────────────────────────────────────────────────────────
+ for turn in st.session_state.history:
+     if turn.get("role") == "user":
+         st.chat_message("user").write(turn["content"])
+     elif turn.get("role") == "assistant":
+         st.chat_message("assistant").write(turn["content"])
+     elif turn.get("role") == "assistant_all":
+         content_items = list(turn["content"].items())
+         if len(content_items) >= 2:
+             cols = st.columns([1, 0.03, 1])
+             persona_label, persona_response = content_items[0]
+             control_label, control_response = content_items[1]
+             with cols[0]:
+                 st.markdown(f"**{persona_label}**")
+                 st.markdown(
+                     f'<div class="answer">{persona_response}</div>',
+                     unsafe_allow_html=True,
+                 )
+             with cols[1]:
+                 st.markdown(
+                     '<div class="vertical-divider"></div>', unsafe_allow_html=True
+                 )
+             with cols[2]:
+                 st.markdown(f"**{control_label}**")
+                 st.markdown(
+                     f'<div class="answer">{control_response}</div>',
+                     unsafe_allow_html=True,
+                 )
+         else:
+             for label, response in content_items:
+                 st.markdown(f"**{label}**")
+                 st.markdown(
+                     f'<div class="answer">{response}</div>', unsafe_allow_html=True
+                 )
assets/memmachine_logo.png ADDED
assets/memverge_logo.png ADDED
gateway_client.py ADDED
@@ -0,0 +1,149 @@
+ import os
+ from datetime import datetime
+
+ import requests
+
+ # Backend server URL - can be set via environment variable
+ # For Hugging Face Spaces: Set MEMORY_SERVER_URL in Space settings (Repository secrets)
+ # For local development: Set MEMORY_SERVER_URL in your .env file
+ # Example: http://3.232.95.65:8080 (MemMachine backend)
+ EXAMPLE_SERVER_PORT = os.getenv("MEMORY_SERVER_URL")
+
+
+
+ def ingest_and_rewrite(user_id: str, query: str, model_type: str = "openai") -> str:
+     """Pass a raw user message through the memory server and get a context-aware query."""
+     print("entered ingest_and_rewrite")
+
+     # First, store the message in memory
+     session_data = {
+         "group_id": user_id,
+         "agent_id": ["assistant"],
+         "user_id": [user_id],
+         "session_id": f"session_{user_id}",
+     }
+     episode_data = {
+         "session": session_data,
+         "producer": user_id,
+         "produced_for": "assistant",
+         "episode_content": query,
+         "episode_type": "message",
+         "metadata": {
+             "speaker": user_id,
+             "timestamp": datetime.now().isoformat(),
+             "type": "message",
+         },
+     }
+
+     # Store the episode
+     store_resp = requests.post(
+         f"{EXAMPLE_SERVER_PORT}/memory",
+         json=episode_data,
+         timeout=1000,
+     )
+     store_resp.raise_for_status()
+
+     # Then search for relevant context
+     search_data = {
+         "session": session_data,
+         "query": query,
+         "limit": 5,
+         "filter": {"producer_id": user_id},
+     }
+
+     search_resp = requests.post(
+         f"{EXAMPLE_SERVER_PORT}/memory/search",
+         json=search_data,
+         timeout=1000,
+     )
+     search_resp.raise_for_status()
+
+     search_results = search_resp.json()
+     content = search_results.get("content", {})
+     episodic_memory = content.get("episodic_memory", [])
+     profile_memory = content.get("profile_memory", [])
+
+     # Format the response similar to example_server.py
+     if profile_memory and episodic_memory:
+         profile_str = "\n".join([str(p) for p in profile_memory]) if isinstance(profile_memory, list) else str(profile_memory)
+         context_str = "\n".join([str(c) for c in episodic_memory]) if isinstance(episodic_memory, list) else str(episodic_memory)
+         return f"Profile: {profile_str}\n\nContext: {context_str}\n\nQuery: {query}"
+     elif profile_memory:
+         profile_str = "\n".join([str(p) for p in profile_memory]) if isinstance(profile_memory, list) else str(profile_memory)
+         return f"Profile: {profile_str}\n\nQuery: {query}"
+     elif episodic_memory:
+         context_str = "\n".join([str(c) for c in episodic_memory]) if isinstance(episodic_memory, list) else str(episodic_memory)
+         return f"Context: {context_str}\n\nQuery: {query}"
+     else:
+         return f"Message ingested successfully. No relevant context found yet.\n\nQuery: {query}"
+
+
+ def add_session_message(user_id: str, msg: str) -> None:
+     """Add a raw message into memory via the memory server."""
+     session_data = {
+         "group_id": user_id,
+         "agent_id": ["assistant"],
+         "user_id": [user_id],
+         "session_id": f"session_{user_id}",
+     }
+     episode_data = {
+         "session": session_data,
+         "producer": user_id,
+         "produced_for": "assistant",
+         "episode_content": msg,
+         "episode_type": "message",
+         "metadata": {
+             "speaker": user_id,
+             "timestamp": datetime.now().isoformat(),
+             "type": "message",
+         },
+     }
+     requests.post(
+         f"{EXAMPLE_SERVER_PORT}/memory",
+         json=episode_data,
+         timeout=5,
+     )
+
+
+ def create_persona_query(user_id: str, query: str) -> str:
+     """Create a persona-aware query by searching memory context via the memory server."""
+     session_data = {
+         "group_id": user_id,
+         "agent_id": ["assistant"],
+         "user_id": [user_id],
+         "session_id": f"session_{user_id}",
+     }
+     search_data = {
+         "session": session_data,
+         "query": query,
+         "limit": 5,
+         "filter": {"producer_id": user_id},
+     }
+
+     resp = requests.post(
+         f"{EXAMPLE_SERVER_PORT}/memory/search",
+         json=search_data,
+         timeout=1000,
+     )
+     resp.raise_for_status()
+
+     search_results = resp.json()
+     content = search_results.get("content", {})
+     profile_memory = content.get("profile_memory", [])
+
+     if profile_memory:
+         profile_str = "\n".join([str(p) for p in profile_memory]) if isinstance(profile_memory, list) else str(profile_memory)
+         return f"Based on your profile: {profile_str}\n\nQuery: {query}"
+     else:
+         return f"Query: {query}"
+
+
+ def add_new_session_message(user_id: str, msg: str) -> None:
+     """Alias for add_session_message for backward compatibility."""
+     add_session_message(user_id, msg)
+
+
+ def delete_profile(user_id: str) -> bool:
+     """Delete all memory for the given user_id via the CRM server."""
+     # NOT IMPLEMENTED
+     return False
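As a quick sanity check, the client above can be exercised directly from a Python shell. This is a sketch only: it assumes a MemMachine server is reachable at http://localhost:8080, and the URL and persona are illustrative. Note that gateway_client.py reads MEMORY_SERVER_URL at import time, so it must be set before the import.

    import os

    # Must be set before importing gateway_client, which reads it at module load.
    os.environ.setdefault("MEMORY_SERVER_URL", "http://localhost:8080")

    from gateway_client import create_persona_query, ingest_and_rewrite

    # Stores the message as an episode, then searches for profile/episodic context.
    print(ingest_and_rewrite(user_id="Charlie", query="What should I cook tonight?"))

    # Search-only variant: prefixes the query with any stored profile facts.
    print(create_persona_query(user_id="Charlie", query="Any dinner ideas?"))
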
llm.py ADDED
@@ -0,0 +1,220 @@
+ import json
+ import os
+ import time
+
+ import boto3
+ import openai
+ from dotenv import load_dotenv
+ from model_config import MODEL_TO_PROVIDER
+
+ # ──────────────────────────────────────────────────────────────
+ # Load environment variables
+ load_dotenv()
+ # ──────────────────────────────────────────────────────────────
+
+ # ──────────────────────────────────────────────────────────────
+ # Configuration
+ # ──────────────────────────────────────────────────────────────
+ MODEL_STRING = "gpt-4.1-mini"  # default model
+ api_key = os.getenv("MODEL_API_KEY")
+ client = openai.OpenAI(api_key=api_key)
+ bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-west-2")
+
+
+ # ──────────────────────────────────────────────────────────────
+ # Model switcher
+ # ──────────────────────────────────────────────────────────────
+ def set_model(model_id: str) -> None:
+     global MODEL_STRING
+     MODEL_STRING = model_id
+     print(f"Model changed to: {model_id}")
+
+
+ def set_provider(provider: str) -> None:
+     global PROVIDER
+     PROVIDER = provider
+
+
+ # ──────────────────────────────────────────────────────────────
+ # High-level Chat wrapper
+ # ──────────────────────────────────────────────────────────────
+ def chat(messages, persona):
+     provider = MODEL_TO_PROVIDER[MODEL_STRING]
+
+     if provider == "openai":
+         print("Using openai: ", MODEL_STRING)
+         system_prompt = None
+         if messages and messages[0].get("role") == "system":
+             system_prompt = messages[0]["content"]
+             messages = messages[1:]
+
+         t0 = time.time()
+         out = client.responses.create(
+             model=MODEL_STRING,
+             instructions=system_prompt,
+             input=messages,  # messages=messages
+             max_output_tokens=500,  # max_tokens=500,
+             temperature=0.5,
+             store=False,  # keeps call stateless
+         )
+
+         dt = time.time() - t0
+
+         text = out.output_text.strip()  # out.choices[0].message.content.strip()
+
+         tok_out = out.usage.output_tokens
+         tok_in = out.usage.input_tokens
+         total_tok = (
+             tok_out + tok_in
+             if tok_out is not None and tok_in is not None
+             else len(text.split())
+         )
+
+         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+     elif provider == "anthropic":
+         print("Using anthropic: ", MODEL_STRING)
+         t0 = time.time()
+
+         claude_messages = [
+             {"role": m["role"], "content": m["content"]} for m in messages
+         ]
+
+         response = bedrock_runtime.invoke_model(
+             modelId=MODEL_STRING,
+             contentType="application/json",
+             accept="application/json",
+             body=json.dumps(
+                 {
+                     "anthropic_version": "bedrock-2023-05-31",
+                     "messages": claude_messages,
+                     "max_tokens": 500,
+                     "temperature": 0.5,
+                 }
+             ),
+         )
+
+         dt = time.time() - t0
+         body = json.loads(response["body"].read())
+
+         text = "".join(
+             part["text"] for part in body["content"] if part["type"] == "text"
+         ).strip()
+         total_tok = len(text.split())
+
+         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+     elif provider == "deepseek":
+         print("Using deepseek: ", MODEL_STRING)
+         t0 = time.time()
+
+         prompt = messages[-1]["content"]
+
+         formatted_prompt = (
+             f"<|begin▁of▁sentence|><|User|>{prompt}<|Assistant|><think>\n"
+         )
+
+         response = bedrock_runtime.invoke_model(
+             modelId=MODEL_STRING,
+             contentType="application/json",
+             accept="application/json",
+             body=json.dumps(
+                 {
+                     "prompt": formatted_prompt,
+                     "max_tokens": 500,
+                     "temperature": 0.5,
+                     "top_p": 0.9,
+                 }
+             ),
+         )
+
+         dt = time.time() - t0
+         body = json.loads(response["body"].read())
+
+         text = body["choices"][0]["text"].strip()
+         total_tok = len(text.split())
+
+         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+     elif provider == "meta":
+         print("Using meta (LLaMA): ", MODEL_STRING)
+         t0 = time.time()
+
+         prompt = messages[-1]["content"]
+
+         # Format prompt in LLaMA-style instruction format
+         formatted_prompt = (
+             "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n"
+             + prompt.strip()
+             + "\n<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"
+         )
+
+         response = bedrock_runtime.invoke_model(
+             modelId=MODEL_STRING,
+             contentType="application/json",
+             accept="application/json",
+             body=json.dumps(
+                 {"prompt": formatted_prompt, "max_gen_len": 512, "temperature": 0.5}
+             ),
+         )
+
+         dt = time.time() - t0
+         body = json.loads(response["body"].read())
+         text = body.get("generation", "").strip()
+         total_tok = len(text.split())
+
+         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+     elif provider == "mistral":
+         print("Using mistral: ", MODEL_STRING)
+         t0 = time.time()
+
+         prompt = messages[-1]["content"]
+         formatted_prompt = f"<s>[INST] {prompt} [/INST]"
+
+         response = bedrock_runtime.invoke_model(
+             modelId=MODEL_STRING,
+             contentType="application/json",
+             accept="application/json",
+             body=json.dumps(
+                 {"prompt": formatted_prompt, "max_tokens": 512, "temperature": 0.5}
+             ),
+         )
+
+         dt = time.time() - t0
+         body = json.loads(response["body"].read())
+
+         text = body["outputs"][0]["text"].strip()
+         total_tok = len(text.split())
+
+         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+
+
+ # ──────────────────────────────────────────────────────────────
+ # Diagnostics / CLI test
+ # ──────────────────────────────────────────────────────────────
+ def check_credentials():
+     required = ["MODEL_API_KEY"]
+     missing = [var for var in required if not os.getenv(var)]
+     if missing:
+         print(f"Missing environment variables: {missing}")
+         return False
+     return True
+
+
+ def test_chat():
+     print("Testing chat...")
+     try:
+         test_messages = [
+             {
+                 "role": "user",
+                 "content": "Hello! Please respond with just 'Test successful'.",
+             }
+         ]
+         text, latency, tokens, tps = chat(test_messages, "Control")
+         print(f"Test passed! {text} {latency:.2f}s {tokens} ⚡ {tps:.1f} tps")
+     except Exception as e:
+         print(f"Test failed: {e}")
+
+
+ if __name__ == "__main__":
+     print("running diagnostics")
+     if check_credentials():
+         test_chat()
+     print("\nDone.")
model_config.py ADDED
@@ -0,0 +1,22 @@
+ PROVIDER_MODEL_MAP = {
+     "openai": ["gpt-4.1-mini"],
+     "anthropic": [
+         "anthropic.claude-3-sonnet-20240229-v1:0",
+         "anthropic.claude-3-5-haiku-20241022-v1:0",
+     ],
+     "deepseek": ["us.deepseek.r1-v1:0"],
+     "meta": ["meta.llama3-8b-instruct-v1:0", "meta.llama3-70b-instruct-v1:0"],
+     "mistral": [
+         "mistral.mixtral-8x7b-instruct-v0:1",
+         "mistral.mistral-7b-instruct-v0:2",
+     ],
+ }
+ # "meta.llama4-maverick-17b-instruct-v1:0" (not currently working)
+
+ MODEL_TO_PROVIDER = {
+     model: provider
+     for provider, models in PROVIDER_MODEL_MAP.items()
+     for model in models
+ }
+
+ MODEL_CHOICES = [model for models in PROVIDER_MODEL_MAP.values() for model in models]
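A quick illustration of the lookups this module derives from PROVIDER_MODEL_MAP; the values come straight from the map above.

    from model_config import MODEL_CHOICES, MODEL_TO_PROVIDER

    assert MODEL_TO_PROVIDER["gpt-4.1-mini"] == "openai"
    assert MODEL_TO_PROVIDER["us.deepseek.r1-v1:0"] == "deepseek"

    # The first entry is what app.py shows as the default selectbox choice (index=0).
    print(MODEL_CHOICES[0])  # gpt-4.1-mini
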
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ altair
+ pandas
+ streamlit
+ requests
+ python-dotenv
+ websocket-client
+ openai
+ anthropic
+ tiktoken
+ pydantic
+ boto3
styles.css ADDED
@@ -0,0 +1,26 @@
+ /* -- Sidebar width & padding -- */
+ section[data-testid="stSidebar"] { width: 230px !important; }
+ section[data-testid="stSidebarContent"] { width: 230px !important;
+     padding: 0.75rem; }
+
+ /* -- Title size -- */
+ h1 { font-size: 2.1rem !important; margin-bottom: 1rem; }
+
+ /* -- Ensure long links wrap inside comparison columns -- */
+ div.answer { white-space: pre-wrap; overflow-wrap: anywhere; }
+
+ /* Tighten spacing between comparison columns */
+ div[data-testid="column"] {
+     padding-left: 0.25rem !important;
+     padding-right: 0.25rem !important;
+     margin-left: 0 !important;
+     margin-right: 0 !important;
+     flex-grow: 1;
+ }
+
+ /* Align vertical divider better */
+ .vertical-divider {
+     height: 100%;
+     border-left: 1px solid #ccc;
+     margin: 0 0.4rem;
+ }