ameyasunilm committed on
Commit 1460916 · verified · 1 Parent(s): dbe2b5b

Upload streaming tool call parser python file for vLLM
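
For readers skimming the diff below: the parser targets model outputs that wrap a JSON array of calls in <TOOLCALL>...</TOOLCALL> tags. A minimal illustration of that format and of how the non-streaming path interprets it (the tool name and arguments here are invented for the example, not part of this commit):

# Illustrative model output in the shape the parser below expects.
model_output = (
    "Let me check that for you. "
    '<TOOLCALL>[{"name": "get_weather", "arguments": {"city": "Paris"}}]</TOOLCALL>'
)
# extract_tool_calls() would return the text before <TOOLCALL> as `content` and
# one ToolCall whose function.arguments is the JSON string '{"city": "Paris"}'.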

Files changed (1)
  1. nemotron_toolcall_parser_streaming.py +236 -0
nemotron_toolcall_parser_streaming.py ADDED
@@ -0,0 +1,236 @@
+ import json
+ import re
+ from collections.abc import Sequence
+ from typing import Union, Optional
+
+ import partial_json_parser
+
+ from vllm.entrypoints.openai.protocol import (
+     ChatCompletionRequest,
+     DeltaFunctionCall,
+     DeltaMessage,
+     DeltaToolCall,
+     ExtractedToolCallInformation,
+     FunctionCall,
+     ToolCall,
+ )
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+     ToolParser,
+     ToolParserManager,
+ )
+ from vllm.logger import init_logger
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
+ from vllm.utils import random_uuid
+
+ logger = init_logger(__name__)
+
+
+ @ToolParserManager.register_module("nemotron_json")
+ class NemotronJSONToolParser(ToolParser):
+
+     def __init__(self, tokenizer: AnyTokenizer):
+         super().__init__(tokenizer)
+
+         # Streaming state tracking
+         self.current_tool_name_sent: bool = False
+         self.prev_tool_call_arr: list[dict] = []
+         self.current_tool_id: int = -1
+         self.streamed_args_for_tool: list[str] = []
+         self.tool_call_ids: list[str] = []  # Track IDs for each tool call
+
+         # Track what we've sent so far in streaming
+         self.sent_tool_calls_count: int = 0
+         self.sent_args_length: dict[int, int] = {}  # tool_idx -> length of args sent
+
+         self.tool_call_start_token: str = "<TOOLCALL>"
+         self.tool_call_end_token: str = "</TOOLCALL>"
+
+         self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)
+
+     def extract_tool_calls(
+         self,
+         model_output: str,
+         request: ChatCompletionRequest,
+     ) -> ExtractedToolCallInformation:
+         """Extract tool calls from non-streaming (complete) output."""
+
+         if self.tool_call_start_token not in model_output:
+             return ExtractedToolCallInformation(
+                 tools_called=False,
+                 tool_calls=[],
+                 content=model_output,
+             )
+
+         try:
+             # Try to extract complete <TOOLCALL>...</TOOLCALL> blocks
+             tool_call_matches = self.tool_call_regex.findall(model_output)
+
+             if tool_call_matches:
+                 # Complete tool call block found
+                 str_tool_calls = tool_call_matches[0].strip()
+             else:
+                 # Incomplete - extract everything after <TOOLCALL>
+                 start_idx = model_output.find(self.tool_call_start_token) + len(self.tool_call_start_token)
+                 str_tool_calls = model_output[start_idx:].strip()
+
+             # Ensure array brackets
+             if not str_tool_calls.startswith("["):
+                 str_tool_calls = "[" + str_tool_calls
+             if not str_tool_calls.endswith("]"):
+                 str_tool_calls = str_tool_calls + "]"
+
+             # Use partial JSON parser for incomplete JSON
+             json_tool_calls = partial_json_parser.loads(str_tool_calls)
+
+             if not isinstance(json_tool_calls, list):
+                 raise ValueError("Tool calls must be a list")
+
+             tool_calls = []
+
+             for tool_call in json_tool_calls:
+                 if not isinstance(tool_call, dict):
+                     continue
+                 try:
+                     tool_calls.append(ToolCall(
+                         type="function",
+                         function=FunctionCall(
+                             name=tool_call.get("name", ""),
+                             arguments=json.dumps(tool_call.get("arguments", {}), ensure_ascii=False) \
+                                 if isinstance(tool_call.get("arguments"), dict) else str(tool_call.get("arguments", "")),
+                         ),
+                     ))
+                 except Exception as e:
+                     logger.warning(f"Failed to parse tool call: {e}")
+                     continue
+
+             content = model_output[:model_output.find(self.tool_call_start_token)].strip()
+
+             return ExtractedToolCallInformation(
+                 tools_called=True if tool_calls else False,
+                 tool_calls=tool_calls,
+                 content=content if content else None,
+             )
+
+         except Exception as e:
+             logger.exception(f"Error extracting tool calls. Response: {model_output}")
+             return ExtractedToolCallInformation(
+                 tools_called=False,
+                 tool_calls=[],
+                 content=model_output,
+             )
+
+     def extract_tool_calls_streaming(
+         self,
+         previous_text: str,
+         current_text: str,
+         delta_text: str,
+         previous_token_ids: Sequence[int],
+         current_token_ids: Sequence[int],
+         delta_token_ids: Sequence[int],
+         request: ChatCompletionRequest,
+     ) -> Union[DeltaMessage, None]:
+         """Extract tool calls from streaming output.
+
+         This incrementally parses the <TOOLCALL> JSON as it streams in,
+         sending delta updates for each tool call and its arguments.
+         """
+
+         # Check if we just started tool calling
+         if self.tool_call_start_token in delta_text and self.tool_call_start_token not in previous_text:
+             # First time seeing <TOOLCALL>, return content before it
+             content_before = delta_text.split(self.tool_call_start_token)[0]
+             if content_before:
+                 return DeltaMessage(content=content_before)
+             # Start of tool call section - no delta yet
+             return None
+
+         # Check if we're not in tool call mode yet
+         if self.tool_call_start_token not in current_text:
+             # Regular content, no tool calls
+             return DeltaMessage(content=delta_text) if delta_text else None
+
+         # We're inside <TOOLCALL>...</TOOLCALL>
+         # For Nemotron, the entire TOOLCALL block is generated at once
+         # So we should only parse when we have the complete </TOOLCALL>
+
+         # Check if we have the complete tool call block yet
+         if self.tool_call_end_token not in current_text:
+             # Incomplete tool call, don't send deltas yet
+             return None
+
+         # We have the complete tool call block, parse it
+         start_idx = current_text.find(self.tool_call_start_token) + len(self.tool_call_start_token)
+         end_idx = current_text.find(self.tool_call_end_token)
+         json_str = current_text[start_idx:end_idx].strip()
+
+         # Parse the complete JSON
+         try:
+             # Ensure we have array brackets
+             if not json_str.startswith("["):
+                 json_str = "[" + json_str
+             if not json_str.endswith("]"):
+                 json_str = json_str + "]"
+
+             # Parse complete JSON
+             tool_calls_arr = json.loads(json_str)
+
+             if not isinstance(tool_calls_arr, list):
+                 return None
+
+             # Generate delta updates for new/updated tool calls
+             delta_tool_calls = []
+
+             for idx, tool_call in enumerate(tool_calls_arr):
+                 if not isinstance(tool_call, dict):
+                     continue
+
+                 # Ensure we have a tool ID for this call
+                 while len(self.tool_call_ids) <= idx:
+                     self.tool_call_ids.append(random_uuid())
+
+                 tool_id = self.tool_call_ids[idx]
+                 tool_name = tool_call.get("name", "")
+                 tool_args = tool_call.get("arguments", {})
+
+                 # Convert arguments to JSON string
+                 if isinstance(tool_args, dict):
+                     args_str = json.dumps(tool_args, ensure_ascii=False)
+                 else:
+                     args_str = str(tool_args)
+
+                 # Check if this is a new tool call
+                 if idx >= self.sent_tool_calls_count:
+                     # New tool call - send ID, name, and complete arguments all at once
+                     # This matches how other models (Llama, etc.) send tool calls
+                     delta_tool_calls.append(DeltaToolCall(
+                         index=idx,
+                         id=tool_id,
+                         type="function",
+                         function=DeltaFunctionCall(
+                             name=tool_name,
+                             arguments=args_str  # Send complete JSON string
+                         )
+                     ))
+                     self.sent_tool_calls_count = idx + 1
+                     self.sent_args_length[idx] = len(args_str)
+
+                 # NOTE: We don't send incremental updates for arguments
+                 # because Nemotron generates complete tool calls in one shot,
+                 # unlike thinking models that stream arguments token-by-token
+
+             if delta_tool_calls:
+                 return DeltaMessage(tool_calls=delta_tool_calls)
+
+         except Exception as e:
+             # JSON parsing failed (expected for incomplete JSON)
+             logger.debug(f"Partial JSON parse failed (expected during streaming): {e}")
+             pass
+
+         # Check if we just completed the tool calls (end tag in this delta)
+         if self.tool_call_end_token in delta_text and self.tool_call_end_token not in previous_text:
+             # We just completed - reset state for next potential tool call
+             self.sent_tool_calls_count = 0
+             self.sent_args_length = {}
+             self.tool_call_ids = []
+
+         return None
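
For context on how a file like this is typically exercised: vLLM's OpenAI-compatible server loads custom tool parsers registered through ToolParserManager and selects them by the name given to register_module ("nemotron_json" here). The sketch below is an illustrative client-side call under that assumption; the server flags mentioned in the comments, the model name, the port, and the get_weather tool are placeholders and assumptions, not part of this commit.

# A minimal sketch, assuming a locally running vLLM OpenAI-compatible server
# started with this parser registered (e.g. with flags along the lines of
# --enable-auto-tool-choice --tool-call-parser nemotron_json and a
# --tool-parser-plugin pointing at nemotron_toolcall_parser_streaming.py).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # assumed local server

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # illustrative tool, not defined in this repo
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

# Non-streaming request: extract_tool_calls() turns the model's
# <TOOLCALL>[...]</TOOLCALL> block into OpenAI-style tool_calls.
response = client.chat.completions.create(
    model="nvidia/placeholder-model",  # replace with the model this repo serves
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)

# Streaming request: extract_tool_calls_streaming() emits each complete call as
# a single DeltaToolCall once the closing </TOOLCALL> tag has been generated.
stream = client.chat.completions.create(
    model="nvidia/placeholder-model",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        print(delta.tool_calls)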