fix output parsing to exclude multi-token prefixes

Alex O'Connell
2025-09-21 13:01:54 -04:00
parent dda3f99208
commit d339bfa1bd
3 changed files with 44 additions and 30 deletions

View File

@@ -206,11 +206,11 @@ class GenericOpenAIAPIClient(LocalLLMClient):
if "tool_calls" in choice["delta"]:
tool_calls = []
for call in choice["delta"]["tool_calls"]:
tool_args, to_say = parse_raw_tool_call(
tool_call, to_say = parse_raw_tool_call(
call["function"], llm_api, user_input)
if tool_args:
tool_calls.append(tool_args)
if tool_call:
tool_calls.append(tool_call)
if to_say:
response_text += to_say
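
The rename from tool_args to tool_call reflects what parse_raw_tool_call actually returns: a complete parsed call plus any text to speak, not a bare argument dict. A minimal stand-in showing that contract (hypothetical types, not the integration's real API):

# Hypothetical sketch of the (tool_call, to_say) contract used above.
from typing import NamedTuple, Optional, Tuple

class ToolInput(NamedTuple):
    tool_name: str
    tool_args: dict

def parse_raw_tool_call(function: dict, llm_api=None, user_input=None) -> Tuple[Optional[ToolInput], str]:
    """Returns the full parsed call (or None) and any text to speak."""
    if not function.get("name"):
        return None, ""
    return ToolInput(function["name"], function.get("arguments", {})), ""

tool_call, to_say = parse_raw_tool_call(
    {"name": "light.turn_on", "arguments": {"entity_id": "light.kitchen"}})
print(tool_call.tool_name)  # the whole call is kept, not just its args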

View File

@@ -261,7 +261,8 @@ class LocalLLMClient:
         message_history[0] = system_prompt
         tool_calls: List[Tuple[llm.ToolInput, Any]] = []
-        for idx in range(max_tool_call_iterations):
+        # if max tool calls is 0 then we expect to generate the response & tool call in one go
+        for idx in range(max(1, max_tool_call_iterations)):
             generation_result = await self._async_generate(message_history, user_input, chat_log, entity_options)
             last_generation_had_tool_calls = False
@@ -296,7 +297,7 @@ class LocalLLMClient:
                 break
             # return an error if we run out of attempts without succeeding
-            if idx == max_tool_call_iterations - 1:
+            if idx == max_tool_call_iterations - 1 and max_tool_call_iterations > 0:
                 intent_response = intent.IntentResponse(language=user_input.language)
                 intent_response.async_set_error(
                     intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
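
These two hunks work together: range(0) yields no iterations, so with max_tool_call_iterations == 0 the old loop never generated a response at all. max(1, ...) guarantees one generation pass, and the added "max_tool_call_iterations > 0" guard keeps that single pass from being misreported as running out of retries. A toy illustration of just the loop arithmetic (the real loop also breaks out early once no more tool calls are made):

def attempts(max_tool_call_iterations: int) -> list[str]:
    log = []
    for idx in range(max(1, max_tool_call_iterations)):
        log.append(f"generate (attempt {idx})")
        # only a real retry budget can be exhausted
        if idx == max_tool_call_iterations - 1 and max_tool_call_iterations > 0:
            log.append("out of retries -> error response")
    return log

print(attempts(0))  # ['generate (attempt 0)'] -- one pass, no retry error
print(attempts(2))  # two passes; only the last one can trip the error
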
@@ -366,22 +367,33 @@ class LocalLLMClient:
         in_tool_call = False
         tool_content = ""
         last_5_tokens = []  # FIXME: this still returns the first few tokens of the tool call if the prefix is split across chunks
+        cur_match_length = 0
         async for chunk in token_generator:
-            _LOGGER.debug(f"Handling chunk: {chunk}")
+            _LOGGER.debug(f"Handling chunk: {chunk} {in_thinking=} {in_tool_call=} {last_5_tokens=}")
             tool_calls: Optional[List[str | llm.ToolInput | dict]]
             content, tool_calls = chunk
+            if not tool_calls:
+                tool_calls = []
             result = TextGenerationResult(
                 response=None,
                 response_streamed=True,
                 tool_calls=None
             )
             if content:
                 last_5_tokens.append(content)
                 if len(last_5_tokens) > 5:
                     last_5_tokens.pop(0)
                 potential_block = "".join(last_5_tokens)
+                if tool_prefix.startswith("".join(last_5_tokens[-(cur_match_length+1):])):
+                    cur_match_length += 1
+                else:
+                    # flush the current match length by appending it to content
+                    if cur_match_length > 0:
+                        content += "".join(last_5_tokens[-cur_match_length:])
+                    cur_match_length = 0
                 if in_tool_call:
                     tool_content += content
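
This hunk is the heart of the fix. A tool-call prefix such as "<tool_call>" can arrive split across several streamed tokens, and the old code only scanned the joined window after the fact, so the first tokens of the prefix leaked into the spoken response (the FIXME above). The new cur_match_length counter holds back trailing tokens that are still a plausible start of tool_prefix and flushes them back into content once the match falls through. A standalone sketch of the same idea, with a made-up prefix and token stream:

tool_prefix = "<tool_call>"

def stream_visible_text(tokens):
    """Yield only text that is provably not part of a split tool prefix."""
    held = []  # trailing tokens that still look like the start of tool_prefix
    for token in tokens:
        held.append(token)
        joined = "".join(held)
        if tool_prefix.startswith(joined):
            if joined == tool_prefix:  # full prefix assembled across tokens
                yield "<TOOL CALL STARTS>"
                held.clear()
            continue  # still a plausible partial match: keep holding it back
        yield joined  # match fell through: flush everything we held
        held.clear()
    if held:
        yield "".join(held)  # flush a dangling partial match at end of stream

print(list(stream_visible_text(["Sure", "!", " ", "<tool", "_call", ">"])))
# ['Sure', '!', ' ', '<TOOL CALL STARTS>'] -- the split prefix never leaks
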
@@ -395,40 +407,40 @@ class LocalLLMClient:
             elif tool_prefix in potential_block and not in_tool_call:
                 in_tool_call = True
+                last_5_tokens.clear()
             elif tool_suffix in potential_block and in_tool_call:
                 in_tool_call = False
-                if not llm_api:
-                    _LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
-                else:
-                    tool_call, to_say = parse_raw_tool_call(tool_content.strip().removeprefix(tool_prefix).removesuffix(tool_suffix), llm_api, user_input)
-                    _LOGGER.debug("Tool call parsed: %s", tool_call)
+                tool_block = tool_content.strip().removeprefix(tool_prefix).removesuffix(tool_suffix)
+                _LOGGER.debug("Raw tool block extracted: %s", tool_block)
+                tool_calls.append(tool_block)
                 tool_content = ""
-                if tool_call:
-                    result.tool_calls = [tool_call]
-                if to_say:
-                    content = to_say
-                else:
-                    content = None
-            result.response = content
+            if cur_match_length == 0:
+                result.response = content
+            parsed_tool_calls: list[llm.ToolInput] = []
             if tool_calls:
                 if not llm_api:
                     _LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
                 else:
-                    result.tool_calls = []
                     for raw_tool_call in tool_calls:
                         if isinstance(raw_tool_call, llm.ToolInput):
-                            result.tool_calls.append(raw_tool_call)
+                            parsed_tool_calls.append(raw_tool_call)
                         else:
-                            tool_input, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, user_input)
-                            if tool_input:
-                                result.tool_calls.append(tool_input)
+                            if isinstance(raw_tool_call, str):
+                                tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, user_input)
+                            else:
+                                tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, user_input)
+                            if tool_call:
+                                _LOGGER.debug("Tool call parsed: %s", tool_call)
+                                parsed_tool_calls.append(tool_call)
                             if to_say:
                                 result.response = to_say
-            if not in_thinking and not in_tool_call:
+            if len(parsed_tool_calls) > 0:
+                result.tool_calls = parsed_tool_calls
+            if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
                 yield result

     def _async_get_all_exposed_domains(self) -> list[str]:
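
Once the closing suffix is seen, the raw block is now queued as a string and parsed together with any native tool calls in a single pass, which has to cope with three shapes: already-parsed llm.ToolInput objects, raw JSON strings cut out of the stream, and OpenAI-style dicts carrying a "function" key. A minimal sketch of that normalization, with stand-in types since the integration's real ToolInput and parser are not shown in this commit:

import json
from dataclasses import dataclass
from typing import Union

@dataclass
class ToolInput:
    tool_name: str
    tool_args: dict

def parse_any(raw: Union[ToolInput, str, dict]) -> ToolInput:
    if isinstance(raw, ToolInput):
        return raw                  # native call, already parsed
    if isinstance(raw, str):
        block = json.loads(raw)     # raw text cut out of the stream
    else:
        block = raw["function"]     # OpenAI-style delta dict
    args = block.get("arguments", {})
    if isinstance(args, str):       # some backends JSON-encode the arguments
        args = json.loads(args)
    return ToolInput(block["name"], args)

mixed = [
    ToolInput("light.turn_off", {"entity_id": "light.porch"}),
    '{"name": "light.turn_on", "arguments": {"entity_id": "light.kitchen"}}',
    {"function": {"name": "climate.set_temperature", "arguments": '{"temperature": 21}'}},
]
for raw in mixed:
    print(parse_any(raw))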

View File

@@ -48,11 +48,13 @@ CSS3_NAME_TO_RGB = {
 class MissingQuantizationException(Exception):
     def __init__(self, missing_quant: str, available_quants: list[str]):
+        super().__init__(missing_quant, available_quants)
         self.missing_quant = missing_quant
         self.available_quants = available_quants

 class MalformedToolCallException(Exception):
     def __init__(self, agent_id: str, tool_call_id: str, tool_name: str, tool_args: str, error_msg: str):
+        super().__init__(agent_id, tool_call_id, tool_name, tool_args, error_msg)
         self.agent_id = agent_id
         self.tool_call_id = tool_call_id
         self.tool_name = tool_name
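
A likely motivation for the new super().__init__(...) calls: they populate exception.args even when the exception is raised with keyword arguments (which BaseException.__new__ cannot capture), so str()/repr() stay informative and the exception survives pickling, since unpickling reconstructs it as cls(*exc.args). Easy to verify:

import pickle

class MissingQuantizationException(Exception):
    def __init__(self, missing_quant: str, available_quants: list[str]):
        super().__init__(missing_quant, available_quants)  # populates .args
        self.missing_quant = missing_quant
        self.available_quants = available_quants

exc = MissingQuantizationException(missing_quant="Q4_K_M",
                                   available_quants=["Q8_0", "F16"])
print(exc.args)                        # ('Q4_K_M', ['Q8_0', 'F16'])
restored = pickle.loads(pickle.dumps(exc))
print(restored.missing_quant)          # Q4_K_M
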
@@ -437,13 +439,13 @@ def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, user_in
             to_say = parsed_tool_call.pop("to_say", "")
             tool_input = llm.ToolInput(
                 tool_name=SERVICE_TOOL_NAME,
-                tool_args=parsed_tool_call,
+                tool_args=args_dict,
             )
         else:
             to_say = ""
             tool_input = llm.ToolInput(
-                tool_name=parsed_tool_call["name"],
-                tool_args=parsed_tool_call["arguments"],
+                tool_name=tool_name,
+                tool_args=args_dict,
             )
     return tool_input, to_say
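
The last hunk implies parse_raw_tool_call now extracts tool_name and args_dict up front, so both ToolInput branches share one decoded argument dict instead of reaching into parsed_tool_call twice. The function's full body is not part of this commit, so the following is only a guess at that shape, with hypothetical details:

import json

SERVICE_TOOL_NAME = "HassCallService"  # stand-in constant

def parse_raw_tool_call_sketch(raw_block):
    parsed = json.loads(raw_block) if isinstance(raw_block, str) else dict(raw_block)
    tool_name = parsed.get("name", SERVICE_TOOL_NAME)
    args_dict = parsed.get("arguments", parsed)
    if isinstance(args_dict, str):
        args_dict = json.loads(args_dict)  # decode JSON-encoded arguments once
    to_say = args_dict.pop("to_say", "") if tool_name == SERVICE_TOOL_NAME else ""
    return {"tool_name": tool_name, "tool_args": args_dict}, to_say

print(parse_raw_tool_call_sketch(
    '{"name": "light.turn_on", "arguments": {"brightness": 128}}'))
# ({'tool_name': 'light.turn_on', 'tool_args': {'brightness': 128}}, '')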