Commit · 1e310b7
1 Parent(s): 8b6f89e
docs minus assets
- .gitignore +4 -1
- Dockerfile +1 -1
- README.md +12 -1
- index.html +1 -1
- parallel_eval/README.md +59 -0
- parallel_eval/game.py +310 -0
- parallel_eval/proctor.py +233 -0
- parallel_eval/requirements.txt +5 -0
- parallel_eval/supernodes.json +19 -0
- src/components/viewer-tab.tsx +69 -5
.gitignore
CHANGED
@@ -28,4 +28,7 @@ tmp
 
 qwen3-final-results.json
 
-__pycache__
+__pycache__
+.venv
+proctor_tmp
+wikihop.db
Dockerfile
CHANGED
@@ -53,7 +53,7 @@ RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
     curl https://huggingface.co/api/whoami-v2 -H "Authorization: Bearer $(cat /run/secrets/HF_TOKEN)"
 
 RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
-    curl -L https://huggingface.co/HuggingFaceTB/simplewiki-pruned-text-350k/
+    curl -L https://huggingface.co/datasets/HuggingFaceTB/simplewiki-pruned-text-350k/blob/main/wikihop.db -H "Authorization: Bearer $(cat /run/secrets/HF_TOKEN)" -o wikihop.db
 
 ENV WIKISPEEDIA_DB_PATH=/home/user/app/wikihop.db
 
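The app is expected to read the downloaded database through the `WIKISPEEDIA_DB_PATH` variable set above. A minimal sketch of how a backend might pick it up; the variable name and the `core_articles` table come from this commit, the rest is illustrative and not the Space's actual server code:

```python
import os
import sqlite3

# WIKISPEEDIA_DB_PATH is set in the Dockerfile; fall back to a local copy for dev runs.
db_path = os.environ.get("WIKISPEEDIA_DB_PATH", "wikihop.db")

# Open the SQLite database read-only; the articles live in the core_articles table
# (the same table game.py queries later in this commit).
conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
count = conn.execute("SELECT COUNT(*) FROM core_articles").fetchone()[0]
print(f"Loaded {count} articles from {db_path}")
```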
README.md
CHANGED
@@ -9,4 +9,15 @@ hf_oauth: true
 hf_oauth_scopes:
 - inference-api
 - email
----
+---
+
+# Can you wikirace faster than an LLM? 🏁
+
+Go head-to-head with Qwen, Gemma, and DeepSeek on the [Hugging Face Space](https://huggingface.co/spaces/HuggingFaceTB/Wikispeedia).
+
+<!-- add gifs -->
+
+
+Or run 100s of agents on any model in parallel for efficient evaluations ([see the parallel_eval README](parallel_eval)).
+
+
index.html
CHANGED
@@ -4,7 +4,7 @@
     <meta charset="UTF-8" />
     <link rel="icon" type="image/svg+xml" href="/vite.svg" />
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-    <title>
+    <title>WikiRacing LLMs</title>
   </head>
   <body>
     <div id="root"></div>
parallel_eval/README.md
ADDED
@@ -0,0 +1,59 @@
+## Setup env
+
+```bash
+uv venv
+source .venv/bin/activate
+uv pip install -r requirements.txt
+
+# pull the wikihop db
+wget https://huggingface.co/datasets/HuggingFaceTB/simplewiki-pruned-text-350k/blob/main/wikihop.db -O wikihop.db
+```
+
+## Which models does it support?
+Under the hood it uses [LiteLLM](https://github.com/BerriAI/litellm), so you can use any major model (don't forget to export the appropriate API key), or host any model from Hugging Face via [vLLM](https://github.com/vllm-project/vllm).
+
+
+## Play the game
+```bash
+# play the game with the CLI
+python game.py --human --start 'Saint Lucia' --end 'Italy' --db wikihop.db
+
+# have the agent play the game (gpt-4o)
+export OPENAI_API_KEY=sk_xxxxx
+python game.py --agent --start 'Saint Lucia' --end 'Italy' --db wikihop.db --model gpt-4o --max-steps 20
+
+# run an evaluation suite with Qwen3 hosted on vLLM, 200 workers
+python proctor.py --model "hosted_vllm/Qwen/Qwen3-30B-A3B" --api-base "http://localhost:8000/v1" --workers 200
+
+# this produces `proctor_tmp/proctor_1-final-results.json`, which can be visualized in the Space, along with the individual reasoning traces for each run. The suite is resumable if it is stopped and is idempotent.
+```
+
+## JQ command to strip out reasoning traces
+The output file will be very large because it contains all the reasoning traces. You can shrink it down and still be able to visualize it with:
+
+```bash
+jq '{
+  article_list: .article_list,
+  num_trials: .num_trials,
+  num_workers: .num_workers,
+  max_steps: .max_steps,
+  agent_settings: .agent_settings,
+  runs: [.runs[] | {
+    model: .model,
+    api_base: .api_base,
+    max_links: .max_links,
+    max_tries: .max_tries, result: .result,
+    start_article: .start_article,
+    destination_article: .destination_article,
+    steps: [.steps[] | {
+      type: .type,
+      article: .article,
+      metadata: (if .metadata.conversation then
+        .metadata | del(.conversation)
+      else
+        .metadata
+      end)
+    }]
+  }]
+}' proctor_tmp/proctor_1-final-results.json > cleaned_data.json
+```
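Since the agent goes through LiteLLM, the routing can be exercised directly from Python before launching a long evaluation. A minimal sketch, assuming a local vLLM server like the one in the `proctor.py` example above; the model name and endpoint are placeholders:

```python
import asyncio

import litellm


async def smoke_test():
    # Same call shape AgentPlayer.get_move uses: the "hosted_vllm/" prefix tells
    # LiteLLM to talk to an OpenAI-compatible vLLM server at api_base.
    response = await litellm.acompletion(
        model="hosted_vllm/Qwen/Qwen3-30B-A3B",   # placeholder model
        api_base="http://localhost:8000/v1",      # placeholder vLLM endpoint
        messages=[{"role": "user", "content": "Reply with <answer>1</answer>."}],
    )
    print(response.choices[0].message.content)


asyncio.run(smoke_test())
```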
parallel_eval/game.py
ADDED
@@ -0,0 +1,310 @@
+from typing import List, Tuple, Dict, Optional
+import sqlite3
+import json
+import litellm
+import re
+import asyncio
+import argparse
+from functools import lru_cache
+class SQLiteDB:
+    def __init__(self, db_path: str):
+        """Initialize the database with path to SQLite database"""
+        self.db_path = db_path
+        self.conn = sqlite3.connect(db_path)
+        self.conn.row_factory = sqlite3.Row
+        self.cursor = self.conn.cursor()
+        self._article_count = self._get_article_count()
+        print(f"Connected to SQLite database with {self._article_count} articles")
+
+    def _get_article_count(self):
+        self.cursor.execute("SELECT COUNT(*) FROM core_articles")
+        return self.cursor.fetchone()[0]
+
+    @lru_cache(maxsize=8192)
+    def get_article_with_links(self, article_title: str) -> Tuple[str, List[str]]:
+        self.cursor.execute(
+            "SELECT title, links_json FROM core_articles WHERE title = ?",
+            (article_title,),
+        )
+        article = self.cursor.fetchone()
+        if not article:
+            return None, []
+
+        links = json.loads(article["links_json"])
+        return article["title"], links
+
+
+class Player:
+    def __init__(self, name: str):
+        self.name = name
+
+    async def get_move(self, game_state: List[Dict]) -> Tuple[str, Dict]:
+        print("Link choices:")
+        for i, link in enumerate(game_state[-1]["links"]):
+            print(f"{i}: {link}")
+
+        idx = int(input(f"Enter the index of the link you want to select: "))
+        return game_state[-1]["links"][idx], {
+            "message": f"{self.name} selected link #{idx}"
+        }  # return the chosen link
+
+
+class AgentPlayer(Player):
+    def __init__(
+        self,
+        model: str,
+        api_base: str,
+        verbose: bool = True,
+        max_links=None,
+        max_tries=10,
+        target_article=None,
+        seed=None
+    ):
+        super().__init__(model)
+        self.model = model
+        self.api_base = api_base
+        self.verbose = verbose
+        self.max_links = max_links
+        self.max_tries = max_tries
+        self.target_article = target_article
+        self.seed = seed
+
+    async def get_move(self, game_state: List[Dict]) -> Tuple[str, Dict]:
+        prompt = self.construct_prompt(game_state)
+
+        conversation = [
+            {"role": "user", "content": prompt}
+        ]
+
+        for try_number in range(self.max_tries):
+            response = await litellm.acompletion(
+                model=self.model,
+                api_base=self.api_base,
+                messages=conversation,
+                seed=self.seed
+            )
+            response = response.choices[0].message.content
+
+            conversation.append({"role": "assistant", "content": response})
+
+            answer, message = self._attempt_to_extract_answer(response, maximum_answer=len(game_state[-1]["links"]))
+
+            # there was a problem with the answer so give the model another chance
+            if answer == -1:
+                conversation.append({"role": "user", "content": message})
+                continue
+
+            assert answer >= 1 and answer <= len(game_state[-1]["links"]), f"Answer {answer} is out of range"
+
+            # we found an answer so we can return it
+            return game_state[-1]["links"][answer-1], {"tries": try_number, "conversation": conversation}
+
+        # we tried the max number of times and still didn't find an answer
+        return -1, {"tries": self.max_tries, "conversation": conversation}
+
+    def construct_prompt(self, game_state: List[Dict]) -> str:
+        current = game_state[-1]["article"]
+        target = self.target_article
+        available_links = game_state[-1]["links"]
+        formatted_links = "\n".join([f"{i+1}. {link}" for i, link in enumerate(available_links)])
+        path_so_far = [step["article"] for step in game_state]
+
+        try:
+            formatted_path = ' -> '.join(path_so_far)
+        except Exception as e:
+            print(f"Error formatting path: {e}")
+            print(game_state)
+            print("Path so far: ", path_so_far)
+            raise e
+
+        return f"""You are playing WikiRun, trying to navigate from one Wikipedia article to another using only links.
+
+IMPORTANT: You MUST put your final answer in <answer>NUMBER</answer> tags, where NUMBER is the link number.
+For example, if you want to choose link 3, output <answer>3</answer>.
+
+Current article: {current}
+Target article: {target}
+Available links (numbered):
+{formatted_links}
+
+Your path so far: {formatted_path}
+
+Think about which link is most likely to lead you toward the target article.
+First, analyze each link briefly and how it connects to your goal, then select the most promising one.
+
+Remember to format your final answer by explicitly writing out the xml number tags like this: <answer>NUMBER</answer>
+"""
+
+    def _attempt_to_extract_answer(self, response: str, maximum_answer: Optional[int] = None) -> Tuple[int, str]:
+        'returns -1 and a message if no answer is found'
+
+        # Extract choice using format <answer>N</answer>
+        choice_match = re.search(r"<answer>(\d+)</answer>", response)
+
+        if choice_match is None:
+            return -1, f"No answer found in response. Please respond with a number between 1 and {maximum_answer} in <answer>NUMBER</answer> tags."
+
+        # check if there are multiple answers
+        multiple_answers = re.findall(r"<answer>(\d+)</answer>", response)
+        if len(multiple_answers) > 1:
+            return -1, "Multiple answers found in response. Please respond with just one."
+
+        answer = choice_match.group(1)
+
+        # try to convert to int
+        try:
+            answer = int(answer)
+        except ValueError:
+            return -1, f"You answered with {answer} but it could not be converted to an integer. Please respond with a number between 1 and {maximum_answer}."
+
+        # check if the answer is too high or too low
+        if answer > maximum_answer or answer < 1:
+            return -1, f"You answered with {answer} but you have to select a number between 1 and {maximum_answer}."
+
+        return answer, ""  # we found an answer so we don't need to return a message
+
+class Game:
+    def __init__(
+        self,
+        start_article: str,
+        target_article: str,
+        db: SQLiteDB,
+        max_allowed_steps: int,
+        player: Player,
+        verbose: bool = True,
+    ):
+        self.start_article = start_article
+        self.target_article = target_article
+        self.db = db
+        self.max_allowed_steps = max_allowed_steps
+        self.steps = []
+        self.steps_taken = 0
+        self.player = player
+        self.verbose = verbose
+        # Ensure the player knows the target article
+        if isinstance(self.player, AgentPlayer):
+            self.player.target_article = self.target_article
+
+    async def run(self):
+
+        if self.verbose:
+            print(f"Starting game from {self.start_article} to {self.target_article}")
+
+        # get the start article
+        _, links = self.db.get_article_with_links(self.start_article)
+
+        self.steps.append(
+            {
+                "type": "start",
+                "article": self.start_article,
+                "links": links,
+                "metadata": {"message": "Game started"},
+            }
+        )
+
+        # while the current article is not the target article and the number of steps taken is less than the max allowed steps
+        while self.steps_taken < self.max_allowed_steps:
+            self.steps_taken += 1
+
+            # Await the async player move
+            player_move, metadata = await self.player.get_move(self.steps)
+
+            # player couldn't select a valid link
+            if player_move == -1:
+                self.steps.append(
+                    {"type": "lose", "article": player_move, "metadata": metadata}
+                )
+                break
+
+            if self.verbose:
+                print(f" -> Step {self.steps_taken}: {player_move}")
+                # input("Press Enter to continue...")
+
+            # if we found it its over
+            if player_move == self.target_article:
+                self.steps.append(
+                    {"type": "win", "article": player_move, "metadata": metadata}
+                )
+                break
+
+            # if not lets get the next article
+            _, links = self.db.get_article_with_links(player_move)
+
+            if len(links) == 0:
+                self.steps.append(
+                    {"type": "lose", "article": player_move, "metadata": metadata}
+                )
+                break
+
+            self.steps.append(
+                {
+                    "type": "move",
+                    "article": player_move,
+                    "links": links,
+                    "metadata": metadata,
+                }
+            )
+
+        return self.steps
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Play the WikiRun game")
+
+    # Add mutual exclusion group for player type
+    player_group = parser.add_mutually_exclusive_group(required=True)
+    player_group.add_argument("--human", action="store_true", help="Play as a human")
+    player_group.add_argument("--agent", action="store_true", help="Use an AI agent to play")
+
+    # Game parameters
+    parser.add_argument("--start", type=str, default="British Library", help="Starting article title")
+    parser.add_argument("--end", type=str, default="Saint Lucia", help="Target article title")
+    parser.add_argument("--db", type=str, required=True, help="Path to SQLite database")
+    parser.add_argument("--max-steps", type=int, default=10, help="Maximum number of steps allowed (default: 10)")
+
+    # Agent parameters (only used with --agent)
+    parser.add_argument("--model", type=str, default="gpt-4o", help="Model to use for the agent (default: gpt-4o)")
+    parser.add_argument("--api-base", type=str, default="https://api.openai.com/v1",
+                        help="API base URL (default: https://api.openai.com/v1)")
+    parser.add_argument("--max-links", type=int, default=200, help="Maximum number of links to consider (default: 200)")
+    parser.add_argument("--max-tries", type=int, default=3, help="Maximum number of tries for the agent (default: 3)")
+    parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility")
+
+    args = parser.parse_args()
+
+    # Initialize the database
+    db = SQLiteDB(args.db)
+
+    # Initialize the player based on the argument
+    if args.human:
+        player = Player("Human")
+    else:  # args.agent is True
+        player = AgentPlayer(
+            model=args.model,
+            api_base=args.api_base,
+            verbose=True,
+            max_links=args.max_links,
+            max_tries=args.max_tries,
+            target_article=args.end,
+            seed=args.seed
+        )
+
+    # Create and run the game
+    game = Game(
+        start_article=args.start,
+        target_article=args.end,
+        db=db,
+        max_allowed_steps=args.max_steps,
+        player=player,
+        verbose=True
+    )
+
+    steps = asyncio.run(game.run())
+
+    print(f"Game over in {len(steps)} steps")
+    for i, step in enumerate(steps):
+        print(f"Step {i}: {step['type']}")
+        print(f"  Article: {step['article']}")
+        print(f"  Links: {step.get('links', [])}")
+        print(f"  Metadata: {step.get('metadata', {})}")
+        print("\n\n")
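For reference, the classes above can also be driven programmatically rather than through the CLI. A minimal sketch using the Saint Lucia → Italy pair from the parallel_eval README; the db path, model, and api_base are placeholders:

```python
import asyncio

from game import SQLiteDB, AgentPlayer, Game


async def main():
    db = SQLiteDB("wikihop.db")  # path assumed from the README setup step

    # AgentPlayer wraps a LiteLLM-backed model; Game.__init__ passes the target
    # article into the agent, and Game.run() returns the list of steps. The last
    # step's "type" records the outcome ("win", "lose", or a final "move" if the
    # step budget ran out).
    player = AgentPlayer(model="gpt-4o", api_base="https://api.openai.com/v1", max_tries=3)
    game = Game("Saint Lucia", "Italy", db, max_allowed_steps=20, player=player)

    steps = await game.run()
    print(steps[-1]["type"], "in", len(steps) - 1, "hops")


asyncio.run(main())
```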
parallel_eval/proctor.py
ADDED
@@ -0,0 +1,233 @@
+from game import AgentPlayer, SQLiteDB, Game
+import os
+import json
+import asyncio
+import argparse
+
+
+class Proctor:
+    def __init__(
+        self,
+        article_list: list[tuple[str, str]],
+        num_trials: int,
+        num_workers: int,
+        max_steps: int,
+        agent_settings: dict,
+        db_path: str,
+        verbose: bool = True,
+        output_dir: str = "./proctor_tmp",
+        proctor_id: str = "proctor_1",
+        starting_seed: int = 42,
+    ):
+        self.article_list = article_list
+        self.num_trials = num_trials
+        self.num_workers = num_workers
+        self.max_steps = max_steps
+        self.agent_settings = agent_settings
+        self.db_path = db_path
+        self.verbose = verbose
+        self.output_dir = output_dir
+        self.proctor_id = proctor_id
+        self.db = SQLiteDB(self.db_path)
+        self.starting_seed = starting_seed
+
+        os.makedirs(self.output_dir, exist_ok=True)
+
+        self.runs = []
+
+        self.setup_runs()
+
+    def setup_runs(self):
+        for start in self.article_list:
+            for destination in self.article_list:
+                if start == destination:
+                    continue
+                for n in range(self.num_trials):
+                    run_id = f"{self.proctor_id}_{start}_{destination}_{n}"
+                    self.runs.append(
+                        Run(
+                            start,
+                            destination,
+                            self.max_steps,
+                            self.agent_settings,
+                            self.db,
+                            self.output_dir,
+                            self.verbose,
+                            run_id,
+                            self.starting_seed + n,
+                        )
+                    )
+                    print(f"Setup run {run_id}")
+
+    async def run(self):
+        semaphore = asyncio.Semaphore(self.num_workers)
+        tasks = []
+
+        async def run_with_semaphore(run_instance):
+            async with semaphore:
+                if self.verbose:
+                    print(f"Starting run {run_instance.id}")
+                await run_instance.run()
+                if self.verbose:
+                    print(f"Finished run {run_instance.id}")
+
+        for run_instance in self.runs:
+            tasks.append(asyncio.create_task(run_with_semaphore(run_instance)))
+
+        await asyncio.gather(*tasks)
+
+        self.analyze_runs()
+
+    def analyze_runs(self):
+        """We need to analyze all the runs into a .json"""
+        final_results = {
+            "article_list": self.article_list,
+            "num_trials": self.num_trials,
+            "num_workers": self.num_workers,
+            "max_steps": self.max_steps,
+            "agent_settings": self.agent_settings,
+            "runs": [],
+        }
+
+        win_count = 0
+        lose_count = 0
+        hops_distribution = []
+
+        for run in self.runs:
+            with open(run.output_file, "r") as f:
+                result = json.load(f)
+            final_results["runs"].append(result)
+            if result["result"] == "win":
+                win_count += 1
+                hops_distribution.append(len(result["steps"]) - 1)
+            else:
+                lose_count += 1
+
+        final_results["hops_distribution"] = hops_distribution
+        final_results["average_hops"] = sum(hops_distribution) / len(hops_distribution)
+        final_results["win_rate"] = win_count / len(self.runs)
+        final_results["lose_rate"] = lose_count / len(self.runs)
+
+        with open(f"{self.output_dir}/{self.proctor_id}-final-results.json", "w") as f:
+            json.dump(final_results, f, indent=4)
+
+
+class Run:
+    def __init__(
+        self,
+        start_article: str,
+        destination_article: str,
+        max_steps: int,
+        agent_settings: dict,
+        db: SQLiteDB,
+        output_dir: str,
+        verbose: bool,
+        id: str,
+        seed: int,
+    ):
+        self.start_article = start_article
+        self.destination_article = destination_article
+        self.max_steps = max_steps
+        self.agent_settings = agent_settings
+        self.db = db
+        self.output_dir = output_dir
+        self.verbose = verbose
+        self.id = id
+        self.seed = seed
+
+        self.output_file = f"{self.output_dir}/run_{self.id}.json"
+
+    async def run(self):
+        if os.path.exists(self.output_file):
+            return
+
+        player = AgentPlayer(
+            model=self.agent_settings["model"],
+            api_base=self.agent_settings["api_base"],
+            max_links=self.agent_settings["max_links"],
+            max_tries=self.agent_settings["max_tries"],
+            verbose=False,
+            seed=self.seed,
+        )
+
+        game = Game(
+            self.start_article,
+            self.destination_article,
+            self.db,
+            self.max_steps,
+            player,
+            verbose=False,
+        )
+
+        steps = await game.run()
+
+        output = {
+            "model": self.agent_settings["model"],
+            "api_base": self.agent_settings["api_base"],
+            "max_links": self.agent_settings["max_links"],
+            "max_tries": self.agent_settings["max_tries"],
+            "start_article": self.start_article,
+            "destination_article": self.destination_article,
+            "steps": steps,
+            "seed": self.seed,
+            "result": steps[-1]["type"],
+        }
+
+        with open(self.output_file, "w") as f:
+            json.dump(output, f, indent=4)
+
+        print(f"Run {self.id} completed in {len(steps)} steps")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run parallel Wikispeedia evaluations")
+    parser.add_argument("--model", type=str, default="gpt-4o", help="Model to use for agent")
+    parser.add_argument("--api-base", type=str, default=None, help="API base URL for hosted models")
+    parser.add_argument("--workers", type=int, default=20, help="Number of parallel workers")
+    parser.add_argument("--trials", type=int, default=1, help="Number of trials per start-destination pair")
+    parser.add_argument("--max-steps", type=int, default=20, help="Maximum steps per game")
+    parser.add_argument("--max-links", type=int, default=200, help="Maximum links per page for agent")
+    parser.add_argument("--max-tries", type=int, default=3, help="Maximum retries for agent")
+    parser.add_argument("--db-path", type=str, default="wikihop.db", help="Path to the wikihop database")
+    parser.add_argument("--output-dir", type=str, default="./proctor_tmp", help="Directory for output files")
+    parser.add_argument("--proctor-id", type=str, default="proctor_1", help="Unique identifier for this proctor run")
+    parser.add_argument("--seed", type=int, default=42, help="Starting random seed")
+    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
+    parser.add_argument("--article-list", type=str, default="supernodes.json",
+                        help="Path to JSON file with list of articles to test")
+
+    args = parser.parse_args()
+
+    # check if db exists
+    if not os.path.exists(args.db_path):
+        raise FileNotFoundError(f"Database file not found at {args.db_path}")
+
+    # check if article list exists
+    if not os.path.exists(args.article_list):
+        raise FileNotFoundError(f"Article list file not found at {args.article_list}")
+
+    # Read article list from file
+    with open(args.article_list, "r") as f:
+        article_list = json.load(f)
+
+    agent_settings = {
+        "model": args.model,
+        "api_base": args.api_base,
+        "max_links": args.max_links,
+        "max_tries": args.max_tries,
+    }
+
+    proctor = Proctor(
+        article_list=article_list,
+        num_trials=args.trials,
+        num_workers=args.workers,
+        max_steps=args.max_steps,
+        agent_settings=agent_settings,
+        db_path=args.db_path,
+        verbose=args.verbose,
+        output_dir=args.output_dir,
+        proctor_id=args.proctor_id,
+        starting_seed=args.seed,
+    )
+
+    asyncio.run(proctor.run())
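The final-results file written by `analyze_runs` is plain JSON, so summaries can be pulled out without the viewer. A small sketch, assuming the default output location used above:

```python
import json
from collections import Counter

# Default output location: <output_dir>/<proctor_id>-final-results.json
with open("proctor_tmp/proctor_1-final-results.json") as f:
    results = json.load(f)

# Each run's "result" is the last step type; anything other than "win" counts as a loss.
outcomes = Counter(run["result"] for run in results["runs"])
print("outcomes:", dict(outcomes))
print("win rate:", results["win_rate"])
print("average hops on wins:", results.get("average_hops"))
```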
parallel_eval/requirements.txt
ADDED
@@ -0,0 +1,5 @@
+litellm>=1.10.0
+asyncio
+tqdm
+sqlite3-wrapper
+aiohttp
parallel_eval/supernodes.json
ADDED
@@ -0,0 +1,19 @@
+[
+    "Soviet Union",
+    "Frank Lloyd Wright",
+    "Major League Baseball",
+    "R (programming language)",
+    "Hinduism",
+    "Singapore General Hospital",
+    "Nepenthes",
+    "Google AI",
+    "Freedom, Pennsylvania",
+    "Iron Man 3",
+    "Central Bank of Nigeria",
+    "Pok\u00e9mon",
+    "Nintendo",
+    "Bachelor of Arts",
+    "Polynesian languages",
+    "France",
+    "Jennifer Aniston"
+]
src/components/viewer-tab.tsx
CHANGED
@@ -3,7 +3,7 @@
 import q3Results from "../../results/qwen3.json"
 import q3_30B_A3B_Results from "../../results/qwen3-30B-A3-results.json"
 // import mockResults from "../../qwen3-final-results.json"
-import { useMemo, useState, useEffect } from "react";
+import { useMemo, useState, useEffect, useRef } from "react";
 import { Card } from "@/components/ui/card";
 import ForceDirectedGraph from "@/components/force-directed-graph";
 import RunsList from "@/components/runs-list";
@@ -16,8 +16,10 @@ import {
 } from "@/components/ui/select";
 import { Run as ForceGraphRun } from "@/components/reasoning-trace";
 import { Badge } from "@/components/ui/badge";
+import { Button } from "@/components/ui/button";
+import { UploadIcon } from "lucide-react";
 
-const
+const defaultModels = {
   "Qwen3-14B": q3Results,
   "Qwen3-30B-A3B": q3_30B_A3B_Results,
 }
@@ -51,10 +53,12 @@ export default function ViewerTab({
   const [runs, setRuns] = useState<Run[]>([]);
   const [selectedModel, setSelectedModel] = useState<string>("Qwen3-14B");
   const [modelStats, setModelStats] = useState<ModelStats | null>(null);
+  const [models, setModels] = useState(defaultModels);
+  const fileInputRef = useRef<HTMLInputElement>(null);
 
   useEffect(() => {
     // Convert the model data to the format expected by RunsList
-    const convertedRuns = models[selectedModel]
+    const convertedRuns = models[selectedModel]?.runs?.map((run: {
       start_article: string;
       destination_article: string;
       steps: { type: string; article: string }[];
@@ -64,7 +68,7 @@ export default function ViewerTab({
       destination_article: run.destination_article,
       steps: run.steps.map((step: { article: string }) => step.article),
       result: run.result
-    }));
+    })) || [];
     setRuns(convertedRuns);
 
     // Calculate model statistics
@@ -105,7 +109,7 @@ export default function ViewerTab({
       minSteps,
       maxSteps
     });
-  }, [selectedModel]);
+  }, [selectedModel, models]);
 
   const handleRunSelect = (runId: number) => {
     setSelectedRun(runId);
@@ -124,6 +128,49 @@ export default function ViewerTab({
     }));
   }, [filterRuns]);
 
+  const handleFileUpload = (event: React.ChangeEvent<HTMLInputElement>) => {
+    const file = event.target.files?.[0];
+    if (!file) return;
+
+    const reader = new FileReader();
+    reader.onload = (e) => {
+      try {
+        const jsonData = JSON.parse(e.target?.result as string);
+
+        // Validate the JSON structure has the required fields
+        if (!jsonData.runs || !Array.isArray(jsonData.runs)) {
+          alert("Invalid JSON format. File must contain a 'runs' array.");
+          return;
+        }
+
+        // Create a filename-based model name, removing extension and path
+        const fileName = file.name.replace(/\.[^/.]+$/, "");
+        const modelName = `Custom: ${fileName}`;
+
+        // Add the new model to the models object
+        setModels(prev => ({
+          ...prev,
+          [modelName]: jsonData
+        }));
+
+        // Select the newly added model
+        setSelectedModel(modelName);
+      } catch (error) {
+        alert(`Error parsing JSON file: ${error.message}`);
+      }
+    };
+    reader.readAsText(file);
+
+    // Reset the file input
+    if (fileInputRef.current) {
+      fileInputRef.current.value = '';
+    }
+  };
+
+  const handleUploadClick = () => {
+    fileInputRef.current?.click();
+  };
+
   return (
     <div className="grid grid-cols-1 md:grid-cols-12 gap-4 h-[calc(100vh-200px)] max-h-[calc(100vh-200px)] overflow-hidden p-2">
       <Card className="p-3 col-span-12 row-start-1">
@@ -143,6 +190,23 @@ export default function ViewerTab({
         </Select>
       </div>
 
+      <Button
+        variant="outline"
+        size="sm"
+        className="flex items-center gap-1"
+        onClick={handleUploadClick}
+      >
+        <UploadIcon size={14} />
+        <span>Upload JSON</span>
+        <input
+          type="file"
+          ref={fileInputRef}
+          accept=".json"
+          className="hidden"
+          onChange={handleFileUpload}
+        />
+      </Button>
+
       {modelStats && (
         <div className="flex flex-wrap gap-1.5 items-center">
           <Badge variant="outline" className="px-2 py-0.5 flex gap-1 items-center">
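The upload handler above only requires a top-level `runs` array, and the viewer reads `start_article`, `destination_article`, `steps[].article`, and `result` from each run, the same fields proctor.py writes. A hypothetical sketch of producing a compatible file by hand; the field values are illustrative:

```python
import json

# Minimal result file the "Upload JSON" button will accept: a top-level "runs" list
# whose entries carry the fields viewer-tab.tsx maps over.
payload = {
    "runs": [
        {
            "model": "my-model",  # illustrative values
            "start_article": "Saint Lucia",
            "destination_article": "Italy",
            "steps": [
                {"type": "start", "article": "Saint Lucia"},
                {"type": "win", "article": "Italy"},
            ],
            "result": "win",
        }
    ]
}

with open("my-model-results.json", "w") as f:
    json.dump(payload, f, indent=4)
```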