mirror of https://github.com/acon96/home-llm.git (synced 2026-01-09 13:48:05 -05:00)

fix output parsing to exclude multi-token prefixes
@@ -206,11 +206,11 @@ class GenericOpenAIAPIClient(LocalLLMClient):
             if "tool_calls" in choice["delta"]:
                 tool_calls = []
                 for call in choice["delta"]["tool_calls"]:
-                    tool_args, to_say = parse_raw_tool_call(
+                    tool_call, to_say = parse_raw_tool_call(
                         call["function"], llm_api, user_input)

-                    if tool_args:
-                        tool_calls.append(tool_args)
+                    if tool_call:
+                        tool_calls.append(tool_call)

                     if to_say:
                         response_text += to_say
@@ -261,7 +261,8 @@ class LocalLLMClient:
         message_history[0] = system_prompt

         tool_calls: List[Tuple[llm.ToolInput, Any]] = []
-        for idx in range(max_tool_call_iterations):
+        # if max tool calls is 0 then we expect to generate the response & tool call in one go
+        for idx in range(max(1, max_tool_call_iterations)):
             generation_result = await self._async_generate(message_history, user_input, chat_log, entity_options)

             last_generation_had_tool_calls = False
@@ -296,7 +297,7 @@ class LocalLLMClient:
                     break

             # return an error if we run out of attempt without succeeding
-            if idx == max_tool_call_iterations - 1:
+            if idx == max_tool_call_iterations - 1 and max_tool_call_iterations > 0:
                 intent_response = intent.IntentResponse(language=user_input.language)
                 intent_response.async_set_error(
                     intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
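Taken together, these two hunks make a `max_tool_call_iterations` of 0 mean "generate the response and the tool call in one pass": `max(1, ...)` guarantees the loop body runs at least once, and the added `max_tool_call_iterations > 0` guard keeps that single pass from being misreported as running out of attempts. A minimal sketch of the resulting control flow (the `generate_once` callable is a hypothetical stand-in for `_async_generate` plus the tool-call bookkeeping):

from typing import Callable

def run_tool_loop(max_tool_call_iterations: int, generate_once: Callable[[], bool]) -> str:
    # max(1, n) ensures the body runs at least once, even when n == 0
    for idx in range(max(1, max_tool_call_iterations)):
        if generate_once():  # True when the model produced a final response
            return "ok"
        # only a configured (> 0) budget can be "exhausted"; n == 0 never errors here
        if idx == max_tool_call_iterations - 1 and max_tool_call_iterations > 0:
            return "error: ran out of tool call attempts"
    return "ok"  # n == 0: single pass, tool call and response generated together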
@@ -366,22 +367,33 @@ class LocalLLMClient:
         in_tool_call = False
         tool_content = ""
-        last_5_tokens = [] # FIXME: this still returns the first few tokens of the tool call if the prefix is split across chunks
+        last_5_tokens = []
+        cur_match_length = 0
         async for chunk in token_generator:
-            _LOGGER.debug(f"Handling chunk: {chunk}")
+            _LOGGER.debug(f"Handling chunk: {chunk} {in_thinking=} {in_tool_call=} {last_5_tokens=}")
+            tool_calls: Optional[List[str | llm.ToolInput | dict]]
             content, tool_calls = chunk

+            if not tool_calls:
+                tool_calls = []
             result = TextGenerationResult(
                 response=None,
                 response_streamed=True,
                 tool_calls=None
             )
             if content:

                 last_5_tokens.append(content)
                 if len(last_5_tokens) > 5:
                     last_5_tokens.pop(0)

                 potential_block = "".join(last_5_tokens)
+                if tool_prefix.startswith("".join(last_5_tokens[-(cur_match_length+1):])):
+                    cur_match_length += 1
+                else:
+                    # flush the current match length by appending it to content
+                    if cur_match_length > 0:
+                        content += "".join(last_5_tokens[-cur_match_length:])
+                    cur_match_length = 0

                 if in_tool_call:
                     tool_content += content
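This hunk is the fix named in the commit title. Instead of only checking whole chunks against `tool_prefix`, `cur_match_length` counts how many trailing tokens currently form a prefix of `tool_prefix`; while that count is non-zero the matched tokens are withheld from the streamed response (see the `cur_match_length == 0` gate in the next hunk), and they are flushed back into `content` the moment the match breaks. A self-contained sketch of the same rolling-match idea (it assumes, like the diff's token window, that the prefix starts at a token boundary; the prefix literal is illustrative):

from typing import Iterable, Iterator

def stream_hiding_prefix(tokens: Iterable[str], tool_prefix: str = "<tool_call>") -> Iterator[str]:
    held: list[str] = []  # trailing tokens that still match the start of tool_prefix
    for token in tokens:
        candidate = "".join(held) + token
        if tool_prefix.startswith(candidate):
            held.append(token)       # partial match: withhold instead of emitting
            if candidate == tool_prefix:
                return               # full prefix seen: the tool call body starts (not sketched)
            continue
        if held:                     # match broke: the held tokens were ordinary text
            yield "".join(held)
            held.clear()
        yield token
    if held:                         # stream ended mid-match: flush what was held
        yield "".join(held)

# list(stream_hiding_prefix(["Sure!", "<tool", "_call>"])) yields only ["Sure!"]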
@@ -395,40 +407,40 @@ class LocalLLMClient:
                 elif tool_prefix in potential_block and not in_tool_call:
                     in_tool_call = True
                     last_5_tokens.clear()

                 elif tool_suffix in potential_block and in_tool_call:
                     in_tool_call = False
-                    if not llm_api:
-                        _LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
-                    else:
-                        tool_call, to_say = parse_raw_tool_call(tool_content.strip().removeprefix(tool_prefix).removesuffix(tool_suffix), llm_api, user_input)
-                        _LOGGER.debug("Tool call parsed: %s", tool_call)
+                    tool_block = tool_content.strip().removeprefix(tool_prefix).removesuffix(tool_suffix)
+                    _LOGGER.debug("Raw tool block extracted: %s", tool_block)
+                    tool_calls.append(tool_block)
                     tool_content = ""
-
-                    if tool_call:
-                        result.tool_calls = [tool_call]
-                    if to_say:
-                        content = to_say
             else:
                 content = None

-            result.response = content
+            if cur_match_length == 0:
+                result.response = content

+            parsed_tool_calls: list[llm.ToolInput] = []
             if tool_calls:
                 if not llm_api:
                     _LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
                 else:
-                    result.tool_calls = []
                     for raw_tool_call in tool_calls:
                         if isinstance(raw_tool_call, llm.ToolInput):
-                            result.tool_calls.append(raw_tool_call)
+                            parsed_tool_calls.append(raw_tool_call)
                         else:
-                            tool_input, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, user_input)
-                            if tool_input:
-                                result.tool_calls.append(tool_input)
+                            if isinstance(raw_tool_call, str):
+                                tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, user_input)
+                            else:
+                                tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, user_input)
+
+                            if tool_call:
+                                _LOGGER.debug("Tool call parsed: %s", tool_call)
+                                parsed_tool_calls.append(tool_call)
                             if to_say:
                                 result.response = to_say

-            if not in_thinking and not in_tool_call:
+            if len(parsed_tool_calls) > 0:
+                result.tool_calls = parsed_tool_calls
+
+            if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
                 yield result

     def _async_get_all_exposed_domains(self) -> list[str]:
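The lower half of this hunk reworks tool-call handling from parse-during-streaming to collect-then-parse: raw blocks accumulate in `tool_calls` as strings (cut out of the streamed tool block), dicts (OpenAI-style deltas carrying a `"function"` payload), or ready-made `llm.ToolInput` objects, and one loop at the end normalizes all three shapes into `parsed_tool_calls`. The new yield condition then suppresses a chunk while a prefix match is still pending, unless a finished tool call makes it worth emitting anyway. A sketch of the three-way dispatch, using stand-in types (`ToolInput` and `parse_block` here are hypothetical simplifications of `llm.ToolInput` and `parse_raw_tool_call`):

import json
from dataclasses import dataclass, field
from typing import Union

@dataclass
class ToolInput:  # stand-in for llm.ToolInput
    tool_name: str
    tool_args: dict = field(default_factory=dict)

def parse_block(raw: Union[str, dict]) -> ToolInput:  # stand-in for parse_raw_tool_call
    data = json.loads(raw) if isinstance(raw, str) else raw
    return ToolInput(tool_name=data["name"], tool_args=data.get("arguments", {}))

def normalize_tool_calls(tool_calls: list[Union[str, dict, ToolInput]]) -> list[ToolInput]:
    parsed: list[ToolInput] = []
    for raw in tool_calls:
        if isinstance(raw, ToolInput):
            parsed.append(raw)                           # already parsed upstream
        elif isinstance(raw, str):
            parsed.append(parse_block(raw))              # raw block cut out of the stream
        else:
            parsed.append(parse_block(raw["function"]))  # OpenAI-style delta dict
    return parsed

# e.g. normalize_tool_calls(['{"name": "HassTurnOn", "arguments": {"name": "kitchen light"}}'])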
@@ -48,11 +48,13 @@ CSS3_NAME_TO_RGB = {

 class MissingQuantizationException(Exception):
     def __init__(self, missing_quant: str, available_quants: list[str]):
         super().__init__(missing_quant, available_quants)
         self.missing_quant = missing_quant
         self.available_quants = available_quants

 class MalformedToolCallException(Exception):
     def __init__(self, agent_id: str, tool_call_id: str, tool_name: str, tool_args: str, error_msg: str):
         super().__init__(agent_id, tool_call_id, tool_name, tool_args, error_msg)
         self.agent_id = agent_id
         self.tool_call_id = tool_call_id
         self.tool_name = tool_name
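Both exception classes forward their fields to `super().__init__`, which keeps `str(exc)` informative and pickling well-behaved. A brief, hypothetical example of how the `MalformedToolCallException` shown above might be raised and handled (all values are illustrative):

try:
    raise MalformedToolCallException(
        agent_id="conversation.llama",
        tool_call_id="call_0",
        tool_name="HassTurnOn",
        tool_args='{"name": kitchen}',  # not valid JSON -> malformed
        error_msg="tool arguments are not valid JSON",
    )
except MalformedToolCallException as err:
    print(f"tool call {err.tool_call_id} ({err.tool_name}) failed: {err.args[-1]}")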
@@ -437,13 +439,13 @@ def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, user_input
             to_say = parsed_tool_call.pop("to_say", "")
             tool_input = llm.ToolInput(
                 tool_name=SERVICE_TOOL_NAME,
-                tool_args=parsed_tool_call,
+                tool_args=args_dict,
             )
         else:
             to_say = ""
             tool_input = llm.ToolInput(
-                tool_name=parsed_tool_call["name"],
-                tool_args=parsed_tool_call["arguments"],
+                tool_name=tool_name,
+                tool_args=args_dict,
             )

     return tool_input, to_say
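The replacements on both branches (`args_dict` for `parsed_tool_call` / `parsed_tool_call["arguments"]`, and `tool_name` for `parsed_tool_call["name"]`) imply that `parse_raw_tool_call` now extracts and normalizes the name and arguments once, before constructing either `ToolInput` variant. A hedged sketch of what that normalization step could look like (the details are an assumption, not the repo's exact code):

import json

def extract_name_and_args(parsed_tool_call: dict) -> tuple[str, dict]:
    tool_name = parsed_tool_call["name"]
    args = parsed_tool_call.get("arguments", {})
    # OpenAI-style payloads often carry arguments as a JSON-encoded string
    args_dict = json.loads(args) if isinstance(args, str) else args
    return tool_name, args_dict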