diff --git a/.gitignore b/.gitignore
index 57f4fe5..5a0ecbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,8 +3,8 @@ loras/
 core/
 config/
 .DS_Store
-data/*.json
-data/*.jsonl
+data/**/*.json
+data/**/*.jsonl
 *.pyc
 main.log
 .venv
diff --git a/custom_components/llama_conversation/backends/generic_openai.py b/custom_components/llama_conversation/backends/generic_openai.py
index c38afa6..6d1631c 100644
--- a/custom_components/llama_conversation/backends/generic_openai.py
+++ b/custom_components/llama_conversation/backends/generic_openai.py
@@ -155,7 +155,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
         session = async_get_clientsession(self.hass)
 
-        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
+        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
             response = None
             chunk = None
             try:
@@ -175,7 +175,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
                             break
 
                     if chunk and chunk.strip():
-                        to_say, tool_calls = self._extract_response(json.loads(chunk), llm_api, agent_id)
+                        to_say, tool_calls = self._extract_response(json.loads(chunk))
                         if to_say or tool_calls:
                             yield to_say, tool_calls
             except asyncio.TimeoutError as err:
@@ -183,14 +183,14 @@ class GenericOpenAIAPIClient(LocalLLMClient):
             except aiohttp.ClientError as err:
                 raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
 
-        return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
+        return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
 
     def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
         request_params = {}
 
         endpoint = "/chat/completions"
         return endpoint, request_params
 
-    def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, agent_id: str) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
+    def _extract_response(self, response_json: dict) -> Tuple[Optional[str], Optional[List[dict]]]:
         if "choices" not in response_json or len(response_json["choices"]) == 0: # finished
             _LOGGER.warning("Response missing or empty 'choices'. Keys present: %s. Full response: %s", list(response_json.keys()), response_json)
@@ -204,16 +204,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
         elif response_json["object"] == "chat.completion.chunk":
             response_text = choice["delta"].get("content", "")
             if "tool_calls" in choice["delta"] and choice["delta"]["tool_calls"] is not None:
-                tool_calls = []
-                for call in choice["delta"]["tool_calls"]:
-                    tool_call, to_say = parse_raw_tool_call(
-                        call["function"], llm_api, agent_id)
-
-                    if tool_call:
-                        tool_calls.append(tool_call)
-
-                    if to_say:
-                        response_text += to_say
+                tool_calls = [call["function"] for call in choice["delta"]["tool_calls"]]
             streamed = True
         else:
             response_text = choice["text"]
@@ -398,7 +389,10 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
         try:
             text = self._extract_response(response_json)
 
-            return TextGenerationResult(response=text, response_streamed=False)
+            if not text:
+                return TextGenerationResult(raise_error=True, error_msg="The Responses API returned an empty response.")
+            # return await self._async_parse_completion(llm_api, agent_id, entity_options, text)
+            return TextGenerationResult(response=text) # Currently we don't extract any info from the response besides the raw model output
         except Exception as err:
             _LOGGER.exception("Failed to parse Responses API payload: %s", err)
             return TextGenerationResult(raise_error=True, error_msg=f"Failed to parse Responses API payload: {err}")
diff --git a/custom_components/llama_conversation/backends/llamacpp.py b/custom_components/llama_conversation/backends/llamacpp.py
index d347357..7a494bf 100644
--- a/custom_components/llama_conversation/backends/llamacpp.py
+++ b/custom_components/llama_conversation/backends/llamacpp.py
@@ -464,5 +464,5 @@ class LlamaCppClient(LocalLLMClient):
                 tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
                 yield content, tool_calls
 
-        return self._async_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())
+        return self._async_stream_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())
diff --git a/custom_components/llama_conversation/backends/ollama.py b/custom_components/llama_conversation/backends/ollama.py
index 0bdb177..e62d40d 100644
--- a/custom_components/llama_conversation/backends/ollama.py
+++ b/custom_components/llama_conversation/backends/ollama.py
@@ -155,33 +155,20 @@ class OllamaAPIClient(LocalLLMClient):
         return models
 
-    def _extract_response(self, response_chunk: ChatResponse) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
-        message = getattr(response_chunk, "message", None)
-        content = getattr(message, "content", None) if message else None
-        raw_tool_calls = getattr(message, "tool_calls", None) if message else None
+    def _extract_response(self, response_chunk: ChatResponse) -> Tuple[Optional[str], Optional[List[dict]]]:
+        content = response_chunk.message.content
+        raw_tool_calls = response_chunk.message.tool_calls
 
-        tool_calls: Optional[List[llm.ToolInput]] = None
         if raw_tool_calls:
-            parsed_tool_calls: list[llm.ToolInput] = []
-            for tool_call in raw_tool_calls:
-                function = getattr(tool_call, "function", None)
-                name = getattr(function, "name", None) if function else None
-                if not name:
-                    continue
-
-                arguments = getattr(function, "arguments", None) or {}
-                if isinstance(arguments, Mapping):
-                    arguments_dict = dict(arguments)
-                else:
-                    arguments_dict = {"raw": arguments}
-
-                parsed_tool_calls.append(llm.ToolInput(tool_name=name, tool_args=arguments_dict))
-
-            if parsed_tool_calls:
-                tool_calls = parsed_tool_calls
-
-        if content is None and not tool_calls:
-            return None, None
+            # return openai formatted tool calls
+            tool_calls = [{
+                "function": {
+                    "name": call.function.name,
+                    "arguments": call.function.arguments,
+                }
+            } for call in raw_tool_calls]
+        else:
+            tool_calls = None
 
         return content, tool_calls
@@ -226,7 +213,7 @@ class OllamaAPIClient(LocalLLMClient):
             tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
         keep_alive_payload = self._format_keep_alive(keep_alive)
 
-        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
+        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
            client = self._build_client(timeout=timeout)
            try:
                stream = await client.chat(
@@ -249,4 +236,4 @@ class OllamaAPIClient(LocalLLMClient):
            except (ResponseError, ConnectionError) as err:
                raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
 
-        return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
+        return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
diff --git a/custom_components/llama_conversation/entity.py b/custom_components/llama_conversation/entity.py
index 8e70747..773f6d3 100644
--- a/custom_components/llama_conversation/entity.py
+++ b/custom_components/llama_conversation/entity.py
@@ -5,7 +5,8 @@ import csv
 import logging
 import os
 import random
-from typing import Literal, Any, List, Dict, Optional, Tuple, AsyncIterator, Generator, AsyncGenerator
+import re
+from typing import Literal, Any, List, Dict, Optional, Sequence, Tuple, AsyncIterator, Generator, AsyncGenerator
 from dataclasses import dataclass
 
 from homeassistant.components import conversation
@@ -14,6 +15,7 @@ from homeassistant.components.homeassistant.exposed_entities import async_should
 from homeassistant.config_entries import ConfigEntry, ConfigSubentry
 from homeassistant.const import MATCH_ALL, CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
+from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import template, entity_registry as er, llm, \
     area_registry as ar, device_registry as dr, entity
 from homeassistant.util import color
@@ -184,28 +186,25 @@ class LocalLLMClient:
                 _LOGGER.debug("Received chunk: %s", input_chunk)
                 tool_calls = input_chunk.tool_calls
 
-                # fix tool calls for the service tool
-                if tool_calls and chat_log.llm_api and chat_log.llm_api.api.id == HOME_LLM_API_ID:
-                    tool_calls = [
-                        llm.ToolInput(
-                            tool_name=SERVICE_TOOL_NAME,
-                            tool_args={**tc.tool_args, "service": tc.tool_name}
-                        ) for tc in tool_calls
-                    ]
+                if tool_calls and not chat_log.llm_api:
+                    raise HomeAssistantError("Model attempted to call a tool but no LLM API was provided")
+
                 yield conversation.AssistantContentDeltaDict(
                     content=input_chunk.response,
-                    tool_calls=tool_calls
+                    tool_calls=tool_calls
                 )
 
         return chat_log.async_add_delta_content_stream(agent_id, stream=async_iterator())
 
-    async def _async_parse_completion(
-        self, llm_api: llm.APIInstance | None,
+    async def _async_stream_parse_completion(
+        self,
+        llm_api: llm.APIInstance | None,
         agent_id: str,
         entity_options: Dict[str, Any],
-        next_token: Optional[Generator[Tuple[Optional[str], Optional[List]]]] = None,
-        anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[List]]]] = None,
+        next_token: Optional[Generator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
+        anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
     ) -> AsyncGenerator[TextGenerationResult, None]:
+        """Parse streaming completion with tool calls from the backend. Accepts either a sync or async token generator."""
         think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
         think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
         tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
@@ -236,7 +235,7 @@ class LocalLLMClient:
         cur_match_length = 0
         async for chunk in token_generator:
             # _LOGGER.debug(f"Handling chunk: {chunk} {in_thinking=} {in_tool_call=} {last_5_tokens=}")
-            tool_calls: Optional[List[str | llm.ToolInput | dict]]
+            tool_calls: Optional[List[str | dict]]
             content, tool_calls = chunk
 
             if not tool_calls:
@@ -289,25 +288,67 @@ class LocalLLMClient:
                     _LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
                 else:
                     for raw_tool_call in tool_calls:
-                        if isinstance(raw_tool_call, llm.ToolInput):
-                            parsed_tool_calls.append(raw_tool_call)
+                        if isinstance(raw_tool_call, str):
+                            tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
                         else:
-                            if isinstance(raw_tool_call, str):
-                                tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, agent_id)
-                            else:
-                                tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, agent_id)
+                            tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
 
-                            if tool_call:
-                                _LOGGER.debug("Tool call parsed: %s", tool_call)
-                                parsed_tool_calls.append(tool_call)
-                            if to_say:
-                                result.response = to_say
+                        if tool_call:
+                            _LOGGER.debug("Tool call parsed: %s", tool_call)
+                            parsed_tool_calls.append(tool_call)
+                        if to_say:
+                            result.response = to_say
 
                 if len(parsed_tool_calls) > 0:
                     result.tool_calls = parsed_tool_calls
 
             if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
                 yield result
+
+    async def _async_parse_completion(
+        self,
+        llm_api: llm.APIInstance | None,
+        agent_id: str,
+        entity_options: Dict[str, Any],
+        completion: str | dict) -> TextGenerationResult:
+        """Parse completion with tool calls from the backend."""
+        think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
+        think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
+        think_regex = re.compile(re.escape(think_prefix) + "(.*?)" + re.escape(think_suffix), re.DOTALL)
+        tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
+        tool_suffix = entity_options.get(CONF_TOOL_CALL_SUFFIX, DEFAULT_TOOL_CALL_SUFFIX)
+        tool_regex = re.compile(re.escape(tool_prefix) + "(.*?)" + re.escape(tool_suffix), re.DOTALL)
+
+        if isinstance(completion, dict):
+            completion = str(completion.get("response", ""))
+
+        # Remove thinking blocks, and extract tool calls
+        tool_calls = tool_regex.findall(completion)
+        completion = think_regex.sub("", completion)
+        completion = tool_regex.sub("", completion)
+
+        to_say = ""
+        parsed_tool_calls: list[llm.ToolInput] = []
+        if len(tool_calls) and not llm_api:
+            _LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
+        else:
+            for raw_tool_call in tool_calls:
+                if isinstance(raw_tool_call, llm.ToolInput):
+                    parsed_tool_calls.append(raw_tool_call)
+                else:
+                    if isinstance(raw_tool_call, str):
+                        tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
+                    else:
+                        tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
+
+                    if tool_call:
+                        _LOGGER.debug("Tool call parsed: %s", tool_call)
+                        parsed_tool_calls.append(tool_call)
+
+        return TextGenerationResult(
+            response=completion + (to_say or ""),
+            tool_calls=parsed_tool_calls,
+        )
 
     def _async_get_all_exposed_domains(self) -> list[str]:
         """Gather all exposed domains"""
diff --git a/custom_components/llama_conversation/utils.py b/custom_components/llama_conversation/utils.py
index aeeada7..fd124eb 100644
--- a/custom_components/llama_conversation/utils.py
+++ b/custom_components/llama_conversation/utils.py
@@ -32,7 +32,7 @@ from .const import (
     ALLOWED_SERVICE_CALL_ARGUMENTS,
     SERVICE_TOOL_ALLOWED_SERVICES,
     SERVICE_TOOL_ALLOWED_DOMAINS,
-    HOME_LLM_API_ID,
+    SERVICE_TOOL_NAME,
 )
 
 from typing import TYPE_CHECKING
@@ -275,26 +275,29 @@ def install_llama_cpp_python(config_dir: str, force_reinstall: bool = False, spe
 def format_url(*, hostname: str, port: str, ssl: bool, path: str):
     return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
 
-def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
-    if llm_api.api.id == HOME_LLM_API_ID:
-        result: List[ChatCompletionTool] = [ {
-            "type": "function",
-            "function": {
-                "name": tool["name"],
-                "description": f"Call the Home Assistant service '{tool['name']}'",
-                "parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
-            }
-        } for tool in get_home_llm_tools(llm_api, domains) ]
-
-    else:
-        result: List[ChatCompletionTool] = [ {
-            "type": "function",
-            "function": {
-                "name": tool.name,
-                "description": tool.description or "",
-                "parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
-            }
-        } for tool in llm_api.tools ]
+def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
+    result: List[ChatCompletionTool] = []
+
+    for tool in llm_api.tools:
+        if tool.name == SERVICE_TOOL_NAME:
+            result.extend([{
+                "type": "function",
+                "function": {
+                    "name": tool["name"],
+                    "description": f"Call the Home Assistant service '{tool['name']}'",
+                    "parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
+                }
+            } for tool in get_home_llm_tools(llm_api, domains) ])
+        else:
+            result.append({
+                "type": "function",
+                "function": {
+                    "name": tool.name,
+                    "description": tool.description or "",
+                    "parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
+                }
+            })
+
     return result
@@ -396,41 +399,44 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic
     return tools
 
-def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
+def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
     if isinstance(raw_block, dict):
         parsed_tool_call = raw_block
     else:
         parsed_tool_call: dict = json.loads(raw_block)
 
-    if llm_api.api.id == HOME_LLM_API_ID:
-        schema_to_validate = vol.Schema({
-            vol.Required('service'): str,
-            vol.Required('target_device'): str,
-            vol.Optional('rgb_color'): str,
-            vol.Optional('brightness'): vol.Coerce(float),
-            vol.Optional('temperature'): vol.Coerce(float),
-            vol.Optional('humidity'): vol.Coerce(float),
-            vol.Optional('fan_mode'): str,
-            vol.Optional('hvac_mode'): str,
-            vol.Optional('preset_mode'): str,
-            vol.Optional('duration'): str,
-            vol.Optional('item'): str,
-        })
-    else:
-        schema_to_validate = vol.Schema({
+    # try to validate either format
+    is_services_tool_call = False
+    try:
+        base_schema_to_validate = vol.Schema({
             vol.Required("name"): str,
             vol.Required("arguments"): vol.Union(str, dict),
         })
-
-    try:
-        schema_to_validate(parsed_tool_call)
+        base_schema_to_validate(parsed_tool_call)
     except vol.Error as ex:
-        _LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
-        raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
+        try:
+            home_llm_schema_to_validate = vol.Schema({
+                vol.Required('service'): str,
+                vol.Required('target_device'): str,
+                vol.Optional('rgb_color'): str,
+                vol.Optional('brightness'): vol.Coerce(float),
+                vol.Optional('temperature'): vol.Coerce(float),
+                vol.Optional('humidity'): vol.Coerce(float),
+                vol.Optional('fan_mode'): str,
+                vol.Optional('hvac_mode'): str,
+                vol.Optional('preset_mode'): str,
+                vol.Optional('duration'): str,
+                vol.Optional('item'): str,
+            })
+            home_llm_schema_to_validate(parsed_tool_call)
+            is_services_tool_call = True
+        except vol.Error as ex:
+            _LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
+            raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
 
     # try to fix certain arguments
-    args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
-    tool_name = parsed_tool_call.get("name", parsed_tool_call.get("service", ""))
+    args_dict = parsed_tool_call if is_services_tool_call else parsed_tool_call["arguments"]
+    tool_name = SERVICE_TOOL_NAME if is_services_tool_call else parsed_tool_call["name"]
 
     if isinstance(args_dict, str):
         if not args_dict.strip():