mirror of https://github.com/acon96/home-llm.git (synced 2026-01-09 13:48:05 -05:00)
clean up tool response extraction
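In short: the backend clients now hand back plain text plus raw OpenAI-style tool-call dictionaries, and the shared parse_raw_tool_call(raw_block, agent_id) helper (which no longer takes llm_api) converts them into llm.ToolInput, detecting the Home-LLM service-call format by schema instead of by API id. The sketch below is illustrative only and not part of the commit; the sample chunk and tool name are invented, while the {"function": {"name": ..., "arguments": ...}} shape mirrors what the updated _extract_response helpers emit.

# Illustrative sketch only (not from the commit): the (text, tool_calls) tuples the
# cleaned-up _extract_response helpers now yield carry plain OpenAI-style dicts.
import json
from typing import Optional

# Invented example chunk; the dict shape matches the Ollama/OpenAI extractors in the diff.
chunk: tuple[Optional[str], Optional[list[dict]]] = (
    None,
    [{"function": {"name": "HassTurnOn", "arguments": {"name": "kitchen light"}}}],
)

text, tool_calls = chunk
for call in tool_calls or []:
    function = call["function"]
    arguments = function["arguments"]
    # parse_raw_tool_call() in the diff also accepts the arguments as a JSON string.
    if isinstance(arguments, str):
        arguments = json.loads(arguments)
    print(f"tool={function['name']!r} args={arguments!r}")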
@@ -155,7 +155,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
         session = async_get_clientsession(self.hass)
 
-        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
+        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
             response = None
             chunk = None
             try:
@@ -175,7 +175,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
                         break
 
                     if chunk and chunk.strip():
-                        to_say, tool_calls = self._extract_response(json.loads(chunk), llm_api, agent_id)
+                        to_say, tool_calls = self._extract_response(json.loads(chunk))
                         if to_say or tool_calls:
                             yield to_say, tool_calls
             except asyncio.TimeoutError as err:
@@ -183,14 +183,14 @@ class GenericOpenAIAPIClient(LocalLLMClient):
             except aiohttp.ClientError as err:
                 raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
 
-        return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
+        return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
 
     def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
         request_params = {}
         endpoint = "/chat/completions"
         return endpoint, request_params
 
-    def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, agent_id: str) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
+    def _extract_response(self, response_json: dict) -> Tuple[Optional[str], Optional[List[dict]]]:
         if "choices" not in response_json or len(response_json["choices"]) == 0: # finished
             _LOGGER.warning("Response missing or empty 'choices'. Keys present: %s. Full response: %s",
                 list(response_json.keys()), response_json)
@@ -204,16 +204,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
         elif response_json["object"] == "chat.completion.chunk":
             response_text = choice["delta"].get("content", "")
             if "tool_calls" in choice["delta"] and choice["delta"]["tool_calls"] is not None:
-                tool_calls = []
-                for call in choice["delta"]["tool_calls"]:
-                    tool_call, to_say = parse_raw_tool_call(
-                        call["function"], llm_api, agent_id)
-
-                    if tool_call:
-                        tool_calls.append(tool_call)
-
-                    if to_say:
-                        response_text += to_say
+                tool_calls = [call["function"] for call in choice["delta"]["tool_calls"]]
             streamed = True
         else:
             response_text = choice["text"]
@@ -398,7 +389,10 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
 
         try:
             text = self._extract_response(response_json)
-            return TextGenerationResult(response=text, response_streamed=False)
+            if not text:
+                return TextGenerationResult(raise_error=True, error_msg="The Responses API returned an empty response.")
+            # return await self._async_parse_completion(llm_api, agent_id, entity_options, text)
+            return TextGenerationResult(response=text) # Currently we don't extract any info from the response besides the raw model output
         except Exception as err:
             _LOGGER.exception("Failed to parse Responses API payload: %s", err)
             return TextGenerationResult(raise_error=True, error_msg=f"Failed to parse Responses API payload: {err}")
@@ -464,5 +464,5 @@ class LlamaCppClient(LocalLLMClient):
                 tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
                 yield content, tool_calls
 
-        return self._async_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())
+        return self._async_stream_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())
@@ -155,33 +155,20 @@ class OllamaAPIClient(LocalLLMClient):
 
         return models
 
-    def _extract_response(self, response_chunk: ChatResponse) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
-        message = getattr(response_chunk, "message", None)
-        content = getattr(message, "content", None) if message else None
-        raw_tool_calls = getattr(message, "tool_calls", None) if message else None
+    def _extract_response(self, response_chunk: ChatResponse) -> Tuple[Optional[str], Optional[List[dict]]]:
+        content = response_chunk.message.content
+        raw_tool_calls = response_chunk.message.tool_calls
 
-        tool_calls: Optional[List[llm.ToolInput]] = None
         if raw_tool_calls:
-            parsed_tool_calls: list[llm.ToolInput] = []
-            for tool_call in raw_tool_calls:
-                function = getattr(tool_call, "function", None)
-                name = getattr(function, "name", None) if function else None
-                if not name:
-                    continue
-
-                arguments = getattr(function, "arguments", None) or {}
-                if isinstance(arguments, Mapping):
-                    arguments_dict = dict(arguments)
-                else:
-                    arguments_dict = {"raw": arguments}
-
-                parsed_tool_calls.append(llm.ToolInput(tool_name=name, tool_args=arguments_dict))
-
-            if parsed_tool_calls:
-                tool_calls = parsed_tool_calls
-
-        if content is None and not tool_calls:
-            return None, None
+            # return openai formatted tool calls
+            tool_calls = [{
+                "function": {
+                    "name": call.function.name,
+                    "arguments": call.function.arguments,
+                }
+            } for call in raw_tool_calls]
+        else:
+            tool_calls = None
 
         return content, tool_calls
@@ -226,7 +213,7 @@ class OllamaAPIClient(LocalLLMClient):
         tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
         keep_alive_payload = self._format_keep_alive(keep_alive)
 
-        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
+        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
             client = self._build_client(timeout=timeout)
             try:
                 stream = await client.chat(
@@ -249,4 +236,4 @@ class OllamaAPIClient(LocalLLMClient):
             except (ResponseError, ConnectionError) as err:
                 raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
 
-        return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
+        return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
@@ -5,7 +5,8 @@ import csv
 import logging
 import os
 import random
-from typing import Literal, Any, List, Dict, Optional, Tuple, AsyncIterator, Generator, AsyncGenerator
+import re
+from typing import Literal, Any, List, Dict, Optional, Sequence, Tuple, AsyncIterator, Generator, AsyncGenerator
 from dataclasses import dataclass
 
 from homeassistant.components import conversation
@@ -14,6 +15,7 @@ from homeassistant.components.homeassistant.exposed_entities import async_should
 from homeassistant.config_entries import ConfigEntry, ConfigSubentry
 from homeassistant.const import MATCH_ALL, CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
+from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import template, entity_registry as er, llm, \
     area_registry as ar, device_registry as dr, entity
 from homeassistant.util import color
@@ -184,28 +186,25 @@ class LocalLLMClient:
             _LOGGER.debug("Received chunk: %s", input_chunk)
 
             tool_calls = input_chunk.tool_calls
-            # fix tool calls for the service tool
-            if tool_calls and chat_log.llm_api and chat_log.llm_api.api.id == HOME_LLM_API_ID:
-                tool_calls = [
-                    llm.ToolInput(
-                        tool_name=SERVICE_TOOL_NAME,
-                        tool_args={**tc.tool_args, "service": tc.tool_name}
-                    ) for tc in tool_calls
-                ]
+            if tool_calls and not chat_log.llm_api:
+                raise HomeAssistantError("Model attempted to call a tool but no LLM API was provided")
 
             yield conversation.AssistantContentDeltaDict(
                 content=input_chunk.response,
                 tool_calls=tool_calls
             )
 
         return chat_log.async_add_delta_content_stream(agent_id, stream=async_iterator())
 
-    async def _async_parse_completion(
-        self, llm_api: llm.APIInstance | None,
+    async def _async_stream_parse_completion(
+        self,
+        llm_api: llm.APIInstance | None,
         agent_id: str,
         entity_options: Dict[str, Any],
-        next_token: Optional[Generator[Tuple[Optional[str], Optional[List]]]] = None,
-        anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[List]]]] = None,
+        next_token: Optional[Generator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
+        anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
     ) -> AsyncGenerator[TextGenerationResult, None]:
         """Parse streaming completion with tool calls from the backend. Accepts either a sync or async token generator."""
         think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
         think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
         tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
@@ -236,7 +235,7 @@ class LocalLLMClient:
         cur_match_length = 0
         async for chunk in token_generator:
             # _LOGGER.debug(f"Handling chunk: {chunk} {in_thinking=} {in_tool_call=} {last_5_tokens=}")
-            tool_calls: Optional[List[str | llm.ToolInput | dict]]
+            tool_calls: Optional[List[str | dict]]
            content, tool_calls = chunk
 
             if not tool_calls:
@@ -289,25 +288,67 @@ class LocalLLMClient:
                    _LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
                else:
                    for raw_tool_call in tool_calls:
-                        if isinstance(raw_tool_call, llm.ToolInput):
-                            parsed_tool_calls.append(raw_tool_call)
-                        else:
-                            if isinstance(raw_tool_call, str):
-                                tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, agent_id)
-                            else:
-                                tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, agent_id)
-
-                            if tool_call:
-                                _LOGGER.debug("Tool call parsed: %s", tool_call)
-                                parsed_tool_calls.append(tool_call)
-                            if to_say:
-                                result.response = to_say
+                        if isinstance(raw_tool_call, str):
+                            tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
+                        else:
+                            tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
+
+                        if tool_call:
+                            _LOGGER.debug("Tool call parsed: %s", tool_call)
+                            parsed_tool_calls.append(tool_call)
+                        if to_say:
+                            result.response = to_say
 
                if len(parsed_tool_calls) > 0:
                    result.tool_calls = parsed_tool_calls
 
            if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
                yield result
 
+    async def _async_parse_completion(
+        self,
+        llm_api: llm.APIInstance | None,
+        agent_id: str,
+        entity_options: Dict[str, Any],
+        completion: str | dict) -> TextGenerationResult:
+        """Parse completion with tool calls from the backend."""
+        think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
+        think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
+        think_regex = re.compile(re.escape(think_prefix) + "(.*?)" + re.escape(think_suffix), re.DOTALL)
+        tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
+        tool_suffix = entity_options.get(CONF_TOOL_CALL_SUFFIX, DEFAULT_TOOL_CALL_SUFFIX)
+        tool_regex = re.compile(re.escape(tool_prefix) + "(.*?)" + re.escape(tool_suffix), re.DOTALL)
+
+        if isinstance(completion, dict):
+            completion = str(completion.get("response", ""))
+
+        # Remove thinking blocks, and extract tool calls
+        tool_calls = tool_regex.findall(completion)
+        completion = think_regex.sub("", completion)
+        completion = tool_regex.sub("", completion)
+
+        to_say = ""
+        parsed_tool_calls: list[llm.ToolInput] = []
+        if len(tool_calls) and not llm_api:
+            _LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
+        else:
+            for raw_tool_call in tool_calls:
+                if isinstance(raw_tool_call, llm.ToolInput):
+                    parsed_tool_calls.append(raw_tool_call)
+                else:
+                    if isinstance(raw_tool_call, str):
+                        tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
+                    else:
+                        tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
+
+                    if tool_call:
+                        _LOGGER.debug("Tool call parsed: %s", tool_call)
+                        parsed_tool_calls.append(tool_call)
+
+        return TextGenerationResult(
+            response=completion + (to_say or ""),
+            tool_calls=parsed_tool_calls,
+        )
+
     def _async_get_all_exposed_domains(self) -> list[str]:
         """Gather all exposed domains"""
@@ -32,7 +32,7 @@ from .const import (
     ALLOWED_SERVICE_CALL_ARGUMENTS,
     SERVICE_TOOL_ALLOWED_SERVICES,
     SERVICE_TOOL_ALLOWED_DOMAINS,
-    HOME_LLM_API_ID,
+    SERVICE_TOOL_NAME,
 )
 
 from typing import TYPE_CHECKING
@@ -275,26 +275,29 @@ def install_llama_cpp_python(config_dir: str, force_reinstall: bool = False, spe
 def format_url(*, hostname: str, port: str, ssl: bool, path: str):
     return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
 
-def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
-    if llm_api.api.id == HOME_LLM_API_ID:
-        result: List[ChatCompletionTool] = [ {
-            "type": "function",
-            "function": {
-                "name": tool["name"],
-                "description": f"Call the Home Assistant service '{tool['name']}'",
-                "parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
-            }
-        } for tool in get_home_llm_tools(llm_api, domains) ]
-
-    else:
-        result: List[ChatCompletionTool] = [ {
-            "type": "function",
-            "function": {
-                "name": tool.name,
-                "description": tool.description or "",
-                "parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
-            }
-        } for tool in llm_api.tools ]
+def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
+    result: List[ChatCompletionTool] = []
+
+    for tool in llm_api.tools:
+        if tool.name == SERVICE_TOOL_NAME:
+            result.extend([{
+                "type": "function",
+                "function": {
+                    "name": tool["name"],
+                    "description": f"Call the Home Assistant service '{tool['name']}'",
+                    "parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
+                }
+            } for tool in get_home_llm_tools(llm_api, domains) ])
+        else:
+            result.append({
+                "type": "function",
+                "function": {
+                    "name": tool.name,
+                    "description": tool.description or "",
+                    "parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
+                }
+            })
 
     return result
@@ -396,41 +399,44 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic
 
     return tools
 
-def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
+def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
     if isinstance(raw_block, dict):
         parsed_tool_call = raw_block
     else:
         parsed_tool_call: dict = json.loads(raw_block)
 
-    if llm_api.api.id == HOME_LLM_API_ID:
-        schema_to_validate = vol.Schema({
-            vol.Required('service'): str,
-            vol.Required('target_device'): str,
-            vol.Optional('rgb_color'): str,
-            vol.Optional('brightness'): vol.Coerce(float),
-            vol.Optional('temperature'): vol.Coerce(float),
-            vol.Optional('humidity'): vol.Coerce(float),
-            vol.Optional('fan_mode'): str,
-            vol.Optional('hvac_mode'): str,
-            vol.Optional('preset_mode'): str,
-            vol.Optional('duration'): str,
-            vol.Optional('item'): str,
-        })
-    else:
-        schema_to_validate = vol.Schema({
-            vol.Required("name"): str,
-            vol.Required("arguments"): vol.Union(str, dict),
-        })
-
-    try:
-        schema_to_validate(parsed_tool_call)
+    # try to validate either format
+    is_services_tool_call = False
+    try:
+        base_schema_to_validate = vol.Schema({
+            vol.Required("name"): str,
+            vol.Required("arguments"): vol.Union(str, dict),
+        })
+        base_schema_to_validate(parsed_tool_call)
     except vol.Error as ex:
-        _LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
-        raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
+        try:
+            home_llm_schema_to_validate = vol.Schema({
+                vol.Required('service'): str,
+                vol.Required('target_device'): str,
+                vol.Optional('rgb_color'): str,
+                vol.Optional('brightness'): vol.Coerce(float),
+                vol.Optional('temperature'): vol.Coerce(float),
+                vol.Optional('humidity'): vol.Coerce(float),
+                vol.Optional('fan_mode'): str,
+                vol.Optional('hvac_mode'): str,
+                vol.Optional('preset_mode'): str,
+                vol.Optional('duration'): str,
+                vol.Optional('item'): str,
+            })
+            home_llm_schema_to_validate(parsed_tool_call)
+            is_services_tool_call = True
+        except vol.Error as ex:
+            _LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
+            raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
 
     # try to fix certain arguments
-    args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
-    tool_name = parsed_tool_call.get("name", parsed_tool_call.get("service", ""))
+    args_dict = parsed_tool_call if is_services_tool_call else parsed_tool_call["arguments"]
+    tool_name = SERVICE_TOOL_NAME if is_services_tool_call else parsed_tool_call["name"]
 
     if isinstance(args_dict, str):
         if not args_dict.strip():