clean up tool response extraction

Alex O'Connell
2025-12-14 00:32:18 -05:00
parent b89a0b44b6
commit c8a5b30e5b
6 changed files with 144 additions and 116 deletions

View File

@@ -155,7 +155,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
session = async_get_clientsession(self.hass)
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
response = None
chunk = None
try:
@@ -175,7 +175,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
break
if chunk and chunk.strip():
to_say, tool_calls = self._extract_response(json.loads(chunk), llm_api, agent_id)
to_say, tool_calls = self._extract_response(json.loads(chunk))
if to_say or tool_calls:
yield to_say, tool_calls
except asyncio.TimeoutError as err:
@@ -183,14 +183,14 @@ class GenericOpenAIAPIClient(LocalLLMClient):
except aiohttp.ClientError as err:
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
endpoint = "/chat/completions"
return endpoint, request_params
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, agent_id: str) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
def _extract_response(self, response_json: dict) -> Tuple[Optional[str], Optional[List[dict]]]:
if "choices" not in response_json or len(response_json["choices"]) == 0: # finished
_LOGGER.warning("Response missing or empty 'choices'. Keys present: %s. Full response: %s",
list(response_json.keys()), response_json)
@@ -204,16 +204,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
elif response_json["object"] == "chat.completion.chunk":
response_text = choice["delta"].get("content", "")
if "tool_calls" in choice["delta"] and choice["delta"]["tool_calls"] is not None:
tool_calls = []
for call in choice["delta"]["tool_calls"]:
tool_call, to_say = parse_raw_tool_call(
call["function"], llm_api, agent_id)
if tool_call:
tool_calls.append(tool_call)
if to_say:
response_text += to_say
tool_calls = [call["function"] for call in choice["delta"]["tool_calls"]]
streamed = True
else:
response_text = choice["text"]
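For illustration, a minimal sketch (payload values invented) of a streamed chat.completion.chunk delta carrying a tool call and what the simplified pass-through now returns; parsing into llm.ToolInput is deferred to _async_stream_parse_completion:

# Sketch only: the chunk shape follows the OpenAI streaming format; values are invented.
chunk = {
    "object": "chat.completion.chunk",
    "choices": [{
        "delta": {
            "content": "",
            "tool_calls": [
                {"index": 0, "function": {"name": "HassTurnOn", "arguments": "{\"name\": \"kitchen light\"}"}},
            ],
        },
    }],
}

choice = chunk["choices"][0]
# Same extraction as above: return the raw function dicts untouched.
tool_calls = [call["function"] for call in choice["delta"]["tool_calls"]]
# -> [{'name': 'HassTurnOn', 'arguments': '{"name": "kitchen light"}'}]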
@@ -398,7 +389,10 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
try:
text = self._extract_response(response_json)
return TextGenerationResult(response=text, response_streamed=False)
if not text:
return TextGenerationResult(raise_error=True, error_msg="The Responses API returned an empty response.")
# return await self._async_parse_completion(llm_api, agent_id, entity_options, text)
return TextGenerationResult(response=text) # Currently we don't extract any info from the response besides the raw model output
except Exception as err:
_LOGGER.exception("Failed to parse Responses API payload: %s", err)
return TextGenerationResult(raise_error=True, error_msg=f"Failed to parse Responses API payload: {err}")

View File

@@ -464,5 +464,5 @@ class LlamaCppClient(LocalLLMClient):
tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
yield content, tool_calls
return self._async_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())
return self._async_stream_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())

View File

@@ -155,33 +155,20 @@ class OllamaAPIClient(LocalLLMClient):
return models
def _extract_response(self, response_chunk: ChatResponse) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
message = getattr(response_chunk, "message", None)
content = getattr(message, "content", None) if message else None
raw_tool_calls = getattr(message, "tool_calls", None) if message else None
def _extract_response(self, response_chunk: ChatResponse) -> Tuple[Optional[str], Optional[List[dict]]]:
content = response_chunk.message.content
raw_tool_calls = response_chunk.message.tool_calls
tool_calls: Optional[List[llm.ToolInput]] = None
if raw_tool_calls:
parsed_tool_calls: list[llm.ToolInput] = []
for tool_call in raw_tool_calls:
function = getattr(tool_call, "function", None)
name = getattr(function, "name", None) if function else None
if not name:
continue
arguments = getattr(function, "arguments", None) or {}
if isinstance(arguments, Mapping):
arguments_dict = dict(arguments)
else:
arguments_dict = {"raw": arguments}
parsed_tool_calls.append(llm.ToolInput(tool_name=name, tool_args=arguments_dict))
if parsed_tool_calls:
tool_calls = parsed_tool_calls
if content is None and not tool_calls:
return None, None
# return openai formatted tool calls
tool_calls = [{
"function": {
"name": call.function.name,
"arguments": call.function.arguments,
}
} for call in raw_tool_calls]
else:
tool_calls = None
return content, tool_calls
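A minimal sketch of the new mapping, using SimpleNamespace stand-ins in place of the real ollama ChatResponse/Message/ToolCall objects (attribute names match the fields accessed above; values are invented):

# Sketch only: stand-ins mimic the attribute layout of ollama's response objects.
from types import SimpleNamespace

response_chunk = SimpleNamespace(
    message=SimpleNamespace(
        content="",
        tool_calls=[
            SimpleNamespace(function=SimpleNamespace(
                name="HassLightSet",
                arguments={"name": "kitchen light", "brightness": 50},
            )),
        ],
    ),
)

# Same conversion as above: emit OpenAI-formatted tool call dicts.
tool_calls = [{
    "function": {
        "name": call.function.name,
        "arguments": call.function.arguments,
    },
} for call in response_chunk.message.tool_calls]
# -> [{'function': {'name': 'HassLightSet', 'arguments': {'name': ..., 'brightness': 50}}}]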
@@ -226,7 +213,7 @@ class OllamaAPIClient(LocalLLMClient):
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
keep_alive_payload = self._format_keep_alive(keep_alive)
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
client = self._build_client(timeout=timeout)
try:
stream = await client.chat(
@@ -249,4 +236,4 @@ class OllamaAPIClient(LocalLLMClient):
except (ResponseError, ConnectionError) as err:
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())

View File

@@ -5,7 +5,8 @@ import csv
import logging
import os
import random
from typing import Literal, Any, List, Dict, Optional, Tuple, AsyncIterator, Generator, AsyncGenerator
import re
from typing import Literal, Any, List, Dict, Optional, Sequence, Tuple, AsyncIterator, Generator, AsyncGenerator
from dataclasses import dataclass
from homeassistant.components import conversation
@@ -14,6 +15,7 @@ from homeassistant.components.homeassistant.exposed_entities import async_should
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import MATCH_ALL, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import template, entity_registry as er, llm, \
area_registry as ar, device_registry as dr, entity
from homeassistant.util import color
@@ -184,28 +186,25 @@ class LocalLLMClient:
_LOGGER.debug("Received chunk: %s", input_chunk)
tool_calls = input_chunk.tool_calls
# fix tool calls for the service tool
if tool_calls and chat_log.llm_api and chat_log.llm_api.api.id == HOME_LLM_API_ID:
tool_calls = [
llm.ToolInput(
tool_name=SERVICE_TOOL_NAME,
tool_args={**tc.tool_args, "service": tc.tool_name}
) for tc in tool_calls
]
if tool_calls and not chat_log.llm_api:
raise HomeAssistantError("Model attempted to call a tool but no LLM API was provided")
yield conversation.AssistantContentDeltaDict(
content=input_chunk.response,
tool_calls=tool_calls
tool_calls=tool_calls
)
return chat_log.async_add_delta_content_stream(agent_id, stream=async_iterator())
async def _async_parse_completion(
self, llm_api: llm.APIInstance | None,
async def _async_stream_parse_completion(
self,
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: Dict[str, Any],
next_token: Optional[Generator[Tuple[Optional[str], Optional[List]]]] = None,
anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[List]]]] = None,
next_token: Optional[Generator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
) -> AsyncGenerator[TextGenerationResult, None]:
"""Parse streaming completion with tool calls from the backend. Accepts either a sync or async token generator."""
think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
@@ -236,7 +235,7 @@ class LocalLLMClient:
cur_match_length = 0
async for chunk in token_generator:
# _LOGGER.debug(f"Handling chunk: {chunk} {in_thinking=} {in_tool_call=} {last_5_tokens=}")
tool_calls: Optional[List[str | llm.ToolInput | dict]]
tool_calls: Optional[List[str | dict]]
content, tool_calls = chunk
if not tool_calls:
@@ -289,25 +288,67 @@ class LocalLLMClient:
_LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
else:
for raw_tool_call in tool_calls:
if isinstance(raw_tool_call, llm.ToolInput):
parsed_tool_calls.append(raw_tool_call)
if isinstance(raw_tool_call, str):
tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
else:
if isinstance(raw_tool_call, str):
tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, agent_id)
else:
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, agent_id)
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
if tool_call:
_LOGGER.debug("Tool call parsed: %s", tool_call)
parsed_tool_calls.append(tool_call)
if to_say:
result.response = to_say
if tool_call:
_LOGGER.debug("Tool call parsed: %s", tool_call)
parsed_tool_calls.append(tool_call)
if to_say:
result.response = to_say
if len(parsed_tool_calls) > 0:
result.tool_calls = parsed_tool_calls
if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
yield result
async def _async_parse_completion(
self,
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: Dict[str, Any],
completion: str | dict) -> TextGenerationResult:
"""Parse completion with tool calls from the backend."""
think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
think_regex = re.compile(re.escape(think_prefix) + "(.*?)" + re.escape(think_suffix), re.DOTALL)
tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
tool_suffix = entity_options.get(CONF_TOOL_CALL_SUFFIX, DEFAULT_TOOL_CALL_SUFFIX)
tool_regex = re.compile(re.escape(tool_prefix) + "(.*?)" + re.escape(tool_suffix), re.DOTALL)
if isinstance(completion, dict):
completion = str(completion.get("response", ""))
# Remove thinking blocks, and extract tool calls
tool_calls = tool_regex.findall(completion)
completion = think_regex.sub("", completion)
completion = tool_regex.sub("", completion)
to_say = ""
parsed_tool_calls: list[llm.ToolInput] = []
if len(tool_calls) and not llm_api:
_LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
else:
for raw_tool_call in tool_calls:
if isinstance(raw_tool_call, llm.ToolInput):
parsed_tool_calls.append(raw_tool_call)
else:
if isinstance(raw_tool_call, str):
tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
else:
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
if tool_call:
_LOGGER.debug("Tool call parsed: %s", tool_call)
parsed_tool_calls.append(tool_call)
return TextGenerationResult(
response=completion + (to_say or ""),
tool_calls=parsed_tool_calls,
)
def _async_get_all_exposed_domains(self) -> list[str]:
"""Gather all exposed domains"""

View File

@@ -32,7 +32,7 @@ from .const import (
ALLOWED_SERVICE_CALL_ARGUMENTS,
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME,
)
from typing import TYPE_CHECKING
@@ -275,26 +275,29 @@ def install_llama_cpp_python(config_dir: str, force_reinstall: bool = False, spe
def format_url(*, hostname: str, port: str, ssl: bool, path: str):
return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
if llm_api.api.id == HOME_LLM_API_ID:
result: List[ChatCompletionTool] = [ {
"type": "function",
"function": {
"name": tool["name"],
"description": f"Call the Home Assistant service '{tool['name']}'",
"parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
}
} for tool in get_home_llm_tools(llm_api, domains) ]
else:
result: List[ChatCompletionTool] = [ {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
}
} for tool in llm_api.tools ]
def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
result: List[ChatCompletionTool] = []
for tool in llm_api.tools:
if tool.name == SERVICE_TOOL_NAME:
result.extend([{
"type": "function",
"function": {
"name": tool["name"],
"description": f"Call the Home Assistant service '{tool['name']}'",
"parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
}
} for tool in get_home_llm_tools(llm_api, domains) ])
else:
result.append({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
}
})
return result
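For reference, a sketch of the entry shape each tool maps to in the standard OpenAI function-calling format; the name, description, and parameters here are invented, since the real parameters come from convert() applied to the tool's schema:

# Sketch only: one illustrative element of the list built above.
example_entry = {
    "type": "function",
    "function": {
        "name": "HassTurnOn",
        "description": "Turns on/opens a device or entity",
        "parameters": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        },
    },
}
# When the Home LLM services tool is encountered, it expands into one such entry
# per service returned by get_home_llm_tools() instead of a single entry.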
@@ -396,41 +399,44 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic
return tools
def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
if isinstance(raw_block, dict):
parsed_tool_call = raw_block
else:
parsed_tool_call: dict = json.loads(raw_block)
if llm_api.api.id == HOME_LLM_API_ID:
schema_to_validate = vol.Schema({
vol.Required('service'): str,
vol.Required('target_device'): str,
vol.Optional('rgb_color'): str,
vol.Optional('brightness'): vol.Coerce(float),
vol.Optional('temperature'): vol.Coerce(float),
vol.Optional('humidity'): vol.Coerce(float),
vol.Optional('fan_mode'): str,
vol.Optional('hvac_mode'): str,
vol.Optional('preset_mode'): str,
vol.Optional('duration'): str,
vol.Optional('item'): str,
})
else:
schema_to_validate = vol.Schema({
# try to validate either format
is_services_tool_call = False
try:
base_schema_to_validate = vol.Schema({
vol.Required("name"): str,
vol.Required("arguments"): vol.Union(str, dict),
})
try:
schema_to_validate(parsed_tool_call)
base_schema_to_validate(parsed_tool_call)
except vol.Error as ex:
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
try:
home_llm_schema_to_validate = vol.Schema({
vol.Required('service'): str,
vol.Required('target_device'): str,
vol.Optional('rgb_color'): str,
vol.Optional('brightness'): vol.Coerce(float),
vol.Optional('temperature'): vol.Coerce(float),
vol.Optional('humidity'): vol.Coerce(float),
vol.Optional('fan_mode'): str,
vol.Optional('hvac_mode'): str,
vol.Optional('preset_mode'): str,
vol.Optional('duration'): str,
vol.Optional('item'): str,
})
home_llm_schema_to_validate(parsed_tool_call)
is_services_tool_call = True
except vol.Error as ex:
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
# try to fix certain arguments
args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
tool_name = parsed_tool_call.get("name", parsed_tool_call.get("service", ""))
args_dict = parsed_tool_call if is_services_tool_call else parsed_tool_call["arguments"]
tool_name = SERVICE_TOOL_NAME if is_services_tool_call else parsed_tool_call["name"]
if isinstance(args_dict, str):
if not args_dict.strip():