add initial implementation for ai task entities

Alex O'Connell
2025-10-26 21:47:23 -04:00
parent ca6050b6d5
commit 03989e37b5
11 changed files with 525 additions and 234 deletions

View File

@@ -57,7 +57,7 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,)
PLATFORMS = (Platform.CONVERSATION, Platform.AI_TASK)
BACKEND_TO_CLS: dict[str, type[LocalLLMClient]] = {
BACKEND_TYPE_LLAMA_CPP: LlamaCppClient,

View File

@@ -0,0 +1,126 @@
"""AI Task integration for Local LLMs."""
from __future__ import annotations
from json import JSONDecodeError
import logging
from enum import StrEnum
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, Context
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from .entity import LocalLLMEntity, LocalLLMClient
from .const import (
CONF_PROMPT,
DEFAULT_PROMPT,
)
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry[LocalLLMClient],
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up AI Task entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "ai_task_data":
continue
async_add_entities(
[LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)],
config_subentry_id=subentry.subentry_id,
)
class ResultExtractionMethod(StrEnum):
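"""How structured data should be extracted from the model response."""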
NONE = "none"
STRUCTURED_OUTPUT = "structure"
TOOL = "tool"
class LocalLLMTaskEntity(
ai_task.AITaskEntity,
LocalLLMEntity,
):
"""Ollama AI Task entity."""
def __init__(self, *args, **kwargs) -> None:
"""Initialize Ollama AI Task entity."""
super().__init__(*args, **kwargs)
if self.client._supports_vision(self.runtime_options):
self._attr_supported_features = (
ai_task.AITaskEntityFeature.GENERATE_DATA |
ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
)
else:
self._attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
async def _async_generate_data(
self,
task: ai_task.GenDataTask,
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
extraction_method = ResultExtractionMethod.NONE
try:
raw_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_PROMPT)
message_history = chat_log.content[:]
if not isinstance(message_history[0], conversation.SystemContent):
system_prompt = conversation.SystemContent(content=self.client._generate_system_prompt(raw_prompt, None, self.runtime_options))
message_history.insert(0, system_prompt)
_LOGGER.debug(f"Generating response for {task.name=}...")
generation_result = await self.client._async_generate(message_history, self.entity_id, chat_log, self.runtime_options)
assistant_message = await anext(generation_result)
if not isinstance(assistant_message, conversation.AssistantContent):
raise HomeAssistantError("Last content in chat log is not an AssistantContent!")
text = assistant_message.content
if not task.structure:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=text,
)
if extraction_method == ResultExtractionMethod.NONE:
raise HomeAssistantError("Task structure provided but no extraction method was specified!")
elif extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
try:
data = json_loads(text)
except JSONDecodeError as err:
_LOGGER.error(
"Failed to parse JSON response: %s. Response: %s",
err,
text,
)
raise HomeAssistantError("Error with Local LLM structured response") from err
elif extraction_method == ResultExtractionMethod.TOOL:
try:
data = assistant_message.tool_calls[0].tool_args
except (IndexError, AttributeError) as err:
_LOGGER.error(
"Failed to extract tool arguments from response: %s. Response: %s",
err,
text,
)
raise HomeAssistantError("Error with Local LLM tool response") from err
else:
raise ValueError() # should not happen
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=data,
)
except Exception as err:
_LOGGER.exception("Unhandled exception while running AI Task '%s'", task.name)
raise HomeAssistantError(f"Unhandled error while running AI Task '{task.name}'") from err

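Note: a minimal sketch of how the new entity could be exercised once this platform is loaded; the entity id and service fields below are assumptions based on the core ai_task integration and are not part of this commit.

# Hypothetical call from custom code or a script (illustrative names only).
result = await hass.services.async_call(
    "ai_task",
    "generate_data",
    {
        "task_name": "whiteboard_summary",
        "instructions": "Summarize the notes on the kitchen whiteboard.",
        "entity_id": "ai_task.local_llm_task",
    },
    blocking=True,
    return_response=True,
)
# 'result' holds the data produced by _async_generate_data above.
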
View File

@@ -119,7 +119,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
def _generate_stream(self,
conversation: List[conversation.Content],
llm_api: llm.APIInstance | None,
user_input: conversation.ConversationInput,
agent_id: str,
entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
model_name = entity_options[CONF_CHAT_MODEL]
temperature = entity_options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
@@ -128,7 +128,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
enable_legacy_tool_calling = entity_options.get(CONF_ENABLE_LEGACY_TOOL_CALLING, DEFAULT_ENABLE_LEGACY_TOOL_CALLING)
endpoint, additional_params = self._chat_completion_params(entity_options)
messages = get_oai_formatted_messages(conversation)
messages = get_oai_formatted_messages(conversation, user_content_as_list=True)
request_params = {
"model": model_name,
@@ -175,7 +175,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
break
if chunk and chunk.strip():
to_say, tool_calls = self._extract_response(json.loads(chunk), llm_api, user_input)
to_say, tool_calls = self._extract_response(json.loads(chunk), llm_api, agent_id)
if to_say or tool_calls:
yield to_say, tool_calls
except asyncio.TimeoutError as err:
@@ -183,14 +183,14 @@ class GenericOpenAIAPIClient(LocalLLMClient):
except aiohttp.ClientError as err:
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
return self._async_parse_completion(llm_api, user_input, entity_options, anext_token=anext_token())
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
endpoint = "/chat/completions"
return endpoint, request_params
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, agent_id: str) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
if "choices" not in response_json or len(response_json["choices"]) == 0: # finished
_LOGGER.warning("Response missing or empty 'choices'. Keys present: %s. Full response: %s",
list(response_json.keys()), response_json)
@@ -207,7 +207,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
tool_calls = []
for call in choice["delta"]["tool_calls"]:
tool_call, to_say = parse_raw_tool_call(
call["function"], llm_api, user_input)
call["function"], llm_api, agent_id)
if tool_call:
tool_calls.append(tool_call)
@@ -366,7 +366,7 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
async def _generate(self,
conversation: List[conversation.Content],
llm_api: llm.APIInstance | None,
user_input: conversation.ConversationInput,
agent_id: str,
entity_options: dict[str, Any]) -> TextGenerationResult:
"""Generate a response using the OpenAI-compatible Responses API (non-streaming endpoint wrapped as a single-chunk stream)."""

View File

@@ -419,7 +419,7 @@ class LlamaCppClient(LocalLLMClient):
def _generate_stream(self,
conversation: List[conversation.Content],
llm_api: llm.APIInstance | None,
user_input: conversation.ConversationInput,
agent_id: str,
entity_options: dict[str, Any],
) -> AsyncGenerator[TextGenerationResult, None]:
"""Async generator that yields TextGenerationResult as tokens are produced."""
@@ -434,7 +434,7 @@ class LlamaCppClient(LocalLLMClient):
_LOGGER.debug(f"Options: {entity_options}")
messages = get_oai_formatted_messages(conversation)
messages = get_oai_formatted_messages(conversation, user_content_as_list=True)
tools = None
if llm_api:
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
@@ -464,5 +464,5 @@ class LlamaCppClient(LocalLLMClient):
tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
yield content, tool_calls
return self._async_parse_completion(llm_api, user_input, entity_options, next_token=next_token())
return self._async_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())

View File

@@ -135,7 +135,7 @@ class OllamaAPIClient(LocalLLMClient):
return response, tool_calls
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput, entity_options: Dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, entity_options: Dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
model_name = entity_options.get(CONF_CHAT_MODEL, "")
context_length = entity_options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
max_tokens = entity_options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -199,4 +199,4 @@ class OllamaAPIClient(LocalLLMClient):
except aiohttp.ClientError as err:
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
return self._async_parse_completion(llm_api, user_input, entity_options, anext_token=anext_token())
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())

View File

@@ -400,6 +400,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
"""Return subentries supported by this integration."""
return {
"conversation": LocalLLMSubentryFlowHandler,
"ai_task_data": LocalLLMSubentryFlowHandler,
}
@@ -836,7 +837,7 @@ def local_llama_config_option_schema(
})
elif backend_type == BACKEND_TYPE_OLLAMA:
result.update({
vol.Required(
vol.Required(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS,
@@ -924,6 +925,8 @@ def local_llama_config_option_schema(
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
): int,
})
elif subentry_type == "ai_task_data":
pass # no additional options for ai_task_data for now
# sort the options
global_order = [

View File

@@ -1,6 +1,6 @@
"""Defines the various LLM Backend Agents"""
from __future__ import annotations
from typing import Literal
from typing import Literal, List, Tuple, Any
import logging
from homeassistant.components.conversation import ConversationInput, ConversationResult, ConversationEntity
@@ -9,11 +9,24 @@ from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.helpers import chat_session
from homeassistant.exceptions import TemplateError, HomeAssistantError
from homeassistant.helpers import chat_session, intent, llm
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .utils import MalformedToolCallException
from .entity import LocalLLMEntity, LocalLLMClient, LocalLLMConfigEntry
from .const import (
CONF_PROMPT,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_MAX_TOOL_CALL_ITERATIONS,
DEFAULT_PROMPT,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_MAX_TOOL_CALL_ITERATIONS,
DOMAIN,
)
@@ -73,4 +86,135 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent, LocalLLMEntit
) as session,
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
):
return await self.client._async_handle_message(user_input, chat_log, self.runtime_options)
raw_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_PROMPT)
refresh_system_prompt = self.runtime_options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
remember_conversation = self.runtime_options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
remember_num_interactions = self.runtime_options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)
max_tool_call_iterations = self.runtime_options.get(CONF_MAX_TOOL_CALL_ITERATIONS, DEFAULT_MAX_TOOL_CALL_ITERATIONS)
llm_api: llm.APIInstance | None = None
if self.runtime_options.get(CONF_LLM_HASS_API):
try:
llm_api = await llm.async_get_api(
self.hass,
self.runtime_options[CONF_LLM_HASS_API],
llm_context=user_input.as_llm_context(DOMAIN)
)
except HomeAssistantError as err:
_LOGGER.error("Error getting LLM API: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
# ensure this chat log has the LLM API instance
chat_log.llm_api = llm_api
if remember_conversation:
message_history = chat_log.content[:]
else:
message_history = []
# trim message history before processing if necessary
if remember_num_interactions and len(message_history) > (remember_num_interactions * 2) + 1:
new_message_history = [message_history[0]] # copy system prompt
new_message_history.extend(message_history[1:][-(remember_num_interactions * 2):])
message_history = new_message_history
# re-generate prompt if necessary
if len(message_history) == 0 or refresh_system_prompt:
try:
system_prompt = conversation.SystemContent(content=self.client._generate_system_prompt(raw_prompt, llm_api, self.runtime_options))
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
if len(message_history) == 0:
message_history.append(system_prompt)
else:
message_history[0] = system_prompt
tool_calls: List[Tuple[llm.ToolInput, Any]] = []
# if max tool calls is 0 then we expect to generate the response & tool call in one go
for idx in range(max(1, max_tool_call_iterations)):
_LOGGER.debug(f"Generating response for {user_input.text=}, iteration {idx+1}/{max_tool_call_iterations}")
generation_result = await self.client._async_generate(message_history, user_input.agent_id, chat_log, self.runtime_options)
last_generation_had_tool_calls = False
while True:
try:
message = await anext(generation_result)
message_history.append(message)
_LOGGER.debug("Added message to history: %s", message)
if message.role == "assistant":
if message.tool_calls and len(message.tool_calls) > 0:
last_generation_had_tool_calls = True
else:
last_generation_had_tool_calls = False
except StopAsyncIteration:
break
except MalformedToolCallException as err:
message_history.extend(err.as_tool_messages())
last_generation_had_tool_calls = True
_LOGGER.debug("Malformed tool call produced", exc_info=err)
except Exception as err:
_LOGGER.exception("There was a problem talking to the backend")
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
f"Sorry, there was a problem talking to the backend: {repr(err)}",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
# If not multi-turn, break after first tool call
# also break if no tool calls were made
if not last_generation_had_tool_calls:
break
# return an error if we run out of attempts without succeeding
if idx == max_tool_call_iterations - 1 and max_tool_call_iterations > 0:
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
f"Sorry, I ran out of attempts to handle your request",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
# generate intent response to Home Assistant
intent_response = intent.IntentResponse(language=user_input.language)
if len(tool_calls) > 0:
str_tools = [f"{input.tool_name}({', '.join(str(x) for x in input.tool_args.values())})" for input, response in tool_calls]
tools_str = '\n'.join(str_tools)
intent_response.async_set_card(
title="Changes",
content=f"Ran the following tools:\n{tools_str}"
)
has_speech = False
for i in range(1, len(message_history)):
cur_msg = message_history[-1 * i]
if isinstance(cur_msg, conversation.AssistantContent) and cur_msg.content:
intent_response.async_set_speech(cur_msg.content)
has_speech = True
break
if not has_speech:
intent_response.async_set_speech("I don't have anything to say right now")
_LOGGER.debug(message_history)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)

View File

@@ -8,45 +8,33 @@ import random
from typing import Literal, Any, List, Dict, Optional, Tuple, AsyncIterator, Generator, AsyncGenerator
from dataclasses import dataclass
from homeassistant.components.conversation import ConversationInput, ConversationResult
from homeassistant.components import conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import MATCH_ALL, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError, HomeAssistantError
from homeassistant.helpers import intent, template, entity_registry as er, llm, \
from homeassistant.helpers import template, entity_registry as er, llm, \
area_registry as ar, device_registry as dr, entity
from homeassistant.util import color
from .utils import closest_color, parse_raw_tool_call, flatten_vol_schema, MalformedToolCallException
from .utils import closest_color, parse_raw_tool_call, flatten_vol_schema
from .const import (
CONF_CHAT_MODEL,
CONF_SELECTED_LANGUAGE,
CONF_PROMPT,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_MAX_TOOL_CALL_ITERATIONS,
CONF_THINKING_PREFIX,
CONF_THINKING_SUFFIX,
CONF_TOOL_CALL_PREFIX,
CONF_TOOL_CALL_SUFFIX,
CONF_ENABLE_LEGACY_TOOL_CALLING,
DEFAULT_PROMPT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_MAX_TOOL_CALL_ITERATIONS,
DOMAIN,
DEFAULT_THINKING_PREFIX,
DEFAULT_THINKING_SUFFIX,
@@ -145,27 +133,31 @@ class LocalLLMClient:
await self.hass.async_add_executor_job(
self._unload_model, entity_options
)
def _supports_vision(self, entity_options: dict[str, Any]) -> bool:
"""Determine if the backend supports vision inputs. Implemented by sub-classes"""
return False
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput, entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
"""Async generator for streaming responses. Subclasses should implement."""
raise NotImplementedError()
async def _generate(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput, entity_options: dict[str, Any]) -> TextGenerationResult:
async def _generate(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, entity_options: dict[str, Any]) -> TextGenerationResult:
"""Call the backend to generate a response from the conversation. Implemented by sub-classes"""
raise NotImplementedError()
async def _async_generate(self, conv: List[conversation.Content], user_input: ConversationInput, chat_log: conversation.chat_log.ChatLog, entity_options: dict[str, Any]):
async def _async_generate(self, conv: List[conversation.Content], agent_id: str, chat_log: conversation.chat_log.ChatLog, entity_options: dict[str, Any]):
"""Default implementation: if streaming is supported, consume the async generator and return the full result."""
if hasattr(self, '_generate_stream'):
# Try to stream and collect the full response
return await self._transform_result_stream(self._generate_stream(conv, chat_log.llm_api, user_input, entity_options), user_input, chat_log)
return await self._transform_result_stream(self._generate_stream(conv, chat_log.llm_api, agent_id, entity_options), agent_id, chat_log)
# Fallback to "blocking" generate
blocking_result = await self._generate(conv, chat_log.llm_api, user_input, entity_options)
blocking_result = await self._generate(conv, chat_log.llm_api, agent_id, entity_options)
return chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
agent_id=agent_id,
content=blocking_result.response,
tool_calls=blocking_result.tool_calls
)
@@ -184,7 +176,7 @@ class LocalLLMClient:
async def _transform_result_stream(
self,
result: AsyncIterator[TextGenerationResult],
user_input: ConversationInput,
agent_id: str,
chat_log: conversation.chat_log.ChatLog
):
async def async_iterator():
@@ -205,152 +197,11 @@ class LocalLLMClient:
tool_calls=tool_calls
)
return chat_log.async_add_delta_content_stream(user_input.agent_id, stream=async_iterator())
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
chat_log: conversation.ChatLog,
entity_options: Dict[str, Any],
) -> conversation.ConversationResult:
raw_prompt = entity_options.get(CONF_PROMPT, DEFAULT_PROMPT)
refresh_system_prompt = entity_options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
remember_conversation = entity_options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
remember_num_interactions = entity_options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)
max_tool_call_iterations = entity_options.get(CONF_MAX_TOOL_CALL_ITERATIONS, DEFAULT_MAX_TOOL_CALL_ITERATIONS)
llm_api: llm.APIInstance | None = None
if entity_options.get(CONF_LLM_HASS_API):
try:
llm_api = await llm.async_get_api(
self.hass,
entity_options[CONF_LLM_HASS_API],
llm_context=user_input.as_llm_context(DOMAIN)
)
except HomeAssistantError as err:
_LOGGER.error("Error getting LLM API: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
# ensure this chat log has the LLM API instance
chat_log.llm_api = llm_api
if remember_conversation:
message_history = chat_log.content[:]
else:
message_history = []
# trim message history before processing if necessary
if remember_num_interactions and len(message_history) > (remember_num_interactions * 2) + 1:
new_message_history = [message_history[0]] # copy system prompt
new_message_history.extend(message_history[1:][-(remember_num_interactions * 2):])
# re-generate prompt if necessary
if len(message_history) == 0 or refresh_system_prompt:
try:
system_prompt = conversation.SystemContent(content=self._generate_system_prompt(raw_prompt, llm_api, entity_options))
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
if len(message_history) == 0:
message_history.append(system_prompt)
else:
message_history[0] = system_prompt
tool_calls: List[Tuple[llm.ToolInput, Any]] = []
# if max tool calls is 0 then we expect to generate the response & tool call in one go
for idx in range(max(1, max_tool_call_iterations)):
_LOGGER.debug(f"Generating response for {user_input.text=}, iteration {idx+1}/{max_tool_call_iterations}")
generation_result = await self._async_generate(message_history, user_input, chat_log, entity_options)
last_generation_had_tool_calls = False
while True:
try:
message = await anext(generation_result)
message_history.append(message)
_LOGGER.debug("Added message to history: %s", message)
if message.role == "assistant":
if message.tool_calls and len(message.tool_calls) > 0:
last_generation_had_tool_calls = True
else:
last_generation_had_tool_calls = False
except StopAsyncIteration:
break
except MalformedToolCallException as err:
message_history.extend(err.as_tool_messages())
last_generation_had_tool_calls = True
_LOGGER.debug("Malformed tool call produced", exc_info=err)
except Exception as err:
_LOGGER.exception("There was a problem talking to the backend")
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
f"Sorry, there was a problem talking to the backend: {repr(err)}",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
# If not multi-turn, break after first tool call
# also break if no tool calls were made
if not last_generation_had_tool_calls:
break
# return an error if we run out of attempt without succeeding
if idx == max_tool_call_iterations - 1 and max_tool_call_iterations > 0:
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
f"Sorry, I ran out of attempts to handle your request",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
# generate intent response to Home Assistant
intent_response = intent.IntentResponse(language=user_input.language)
if len(tool_calls) > 0:
str_tools = [f"{input.tool_name}({', '.join(str(x) for x in input.tool_args.values())})" for input, response in tool_calls]
tools_str = '\n'.join(str_tools)
intent_response.async_set_card(
title="Changes",
content=f"Ran the following tools:\n{tools_str}"
)
has_speech = False
for i in range(1, len(message_history)):
cur_msg = message_history[-1 * i]
if isinstance(cur_msg, conversation.AssistantContent) and cur_msg.content:
intent_response.async_set_speech(cur_msg.content)
has_speech = True
break
if not has_speech:
intent_response.async_set_speech("I don't have anything to say right now")
_LOGGER.debug(message_history)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
return chat_log.async_add_delta_content_stream(agent_id, stream=async_iterator())
async def _async_parse_completion(
self, llm_api: llm.APIInstance | None,
user_input: ConversationInput,
agent_id: str,
entity_options: Dict[str, Any],
next_token: Optional[Generator[Tuple[Optional[str], Optional[List]]]] = None,
anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[List]]]] = None,
@@ -442,9 +293,9 @@ class LocalLLMClient:
parsed_tool_calls.append(raw_tool_call)
else:
if isinstance(raw_tool_call, str):
tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, user_input)
tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, agent_id)
else:
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, user_input)
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, agent_id)
if tool_call:
_LOGGER.debug("Tool call parsed: %s", tool_call)

View File

@@ -35,10 +35,10 @@
"config_subentries": {
"conversation": {
"initiate_flow": {
"user": "Add conversation agent",
"reconfigure": "Reconfigure conversation agent"
"user": "Add Conversation Agent",
"reconfigure": "Reconfigure Conversation Agent"
},
"entry_type": "Conversation agent",
"entry_type": "Conversation Agent",
"error": {
"download_failed": "The download failed to complete: {exception}",
"missing_quantization": "The GGUF quantization level {missing} does not exist in the provided HuggingFace repo. The following quantization levels were found: {available}",
@@ -176,6 +176,145 @@
"title": "Configure the selected model"
}
}
},
"ai_task_data": {
"initiate_flow": {
"user": "Add AI Task Handler",
"reconfigure": "Reconfigure AI Task Handler"
},
"entry_type": "AI Task Handler",
"error": {
"download_failed": "The download failed to complete: {exception}",
"missing_quantization": "The GGUF quantization level {missing} does not exist in the provided HuggingFace repo. The following quantization levels were found: {available}",
"no_supported_ggufs": "The provided HuggingFace repo does not contain any compatible GGUF files!",
"missing_model_api": "The selected model is not provided by this API. The available models have been populated in the dropdown.",
"missing_model_file": "The provided file does not exist.",
"other_existing_local": "Another model is already loaded locally. Please unload it or configure a remote model.",
"unknown": "Unexpected error",
"sys_refresh_caching_enabled": "System prompt refresh must be enabled for prompt caching to work!",
"missing_gbnf_file": "The GBNF file was not found: {filename}",
"missing_icl_file": "The in context learning example CSV file was not found: {filename}"
},
"progress": {
"download": "Please wait while the model is being downloaded from HuggingFace. This can take a few minutes."
},
"abort": {
"reconfigure_successful": "Successfully updated model options."
},
"step": {
"pick_model": {
"data": {
"huggingface_model": "Model Name",
"downloaded_model_file": "Local file name",
"downloaded_model_quantization": "Downloaded model quantization"
},
"description": "Select a model to use. \n\n**Models supported out of the box:**\n1. [Home LLM](https://huggingface.co/collections/acon96/home-llm-6618762669211da33bb22c5a): Home 3B & Home 1B\n2. Mistral: [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) or [Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)\n3. Llama 3: [8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and [70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)",
"title": "Pick Model"
},
"model_parameters": {
"data": {
"max_new_tokens": "Maximum tokens to return in response",
"prompt": "System Prompt",
"temperature": "Temperature",
"top_k": "Top K",
"top_p": "Top P",
"min_p": "Min P",
"typical_p": "Typical P",
"request_timeout": "Remote Request Timeout (seconds)",
"ollama_keep_alive": "(ollama) Keep Alive/Inactivity Timeout (minutes)",
"ollama_json_mode": "(ollama) JSON Output Mode",
"extra_attributes_to_expose": "Additional attribute to expose in the context",
"enable_flash_attention": "Enable Flash Attention",
"gbnf_grammar": "Enable GBNF Grammar",
"gbnf_grammar_file": "GBNF Grammar Filename",
"openai_api_key": "API Key",
"text_generation_webui_admin_key": "(text-generation-webui) Admin Key",
"service_call_regex": "Service Call Regex",
"in_context_examples": "Enable in context learning (ICL) examples",
"in_context_examples_file": "In context learning examples CSV filename",
"num_in_context_examples": "Number of ICL examples to generate",
"text_generation_webui_preset": "(text-generation-webui) Generation Preset/Character Name",
"text_generation_webui_chat_mode": "(text-generation-webui) Chat Mode",
"prompt_caching": "Enable Prompt Caching",
"prompt_caching_interval": "Prompt Caching fastest refresh interval (sec)",
"context_length": "Context Length",
"batch_size": "(llama.cpp) Batch Size",
"n_threads": "(llama.cpp) Thread Count",
"n_batch_threads": "(llama.cpp) Batch Thread Count",
"thinking_prefix": "Reasoning Content Prefix",
"thinking_suffix": "Reasoning Content Suffix",
"tool_call_prefix": "Tool Call Prefix",
"tool_call_suffix": "Tool Call Suffix",
"enable_legacy_tool_calling": "Enable Legacy Tool Calling",
"max_tool_call_iterations": "Maximum Tool Call Attempts"
},
"data_description": {
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
"gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.",
"prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below.",
"enable_legacy_tool_calling": "Prefer to process tool calls locally rather than relying on the backend to handle the tool calling format. Can be more reliable, however it requires properly setting the tool call prefix and suffix.",
"max_tool_call_iterations": "Set to 0 to generate the response and tool call in one attempt, without looping (use this for Home models v1-v3)."
},
"description": "Please configure the model according to how it should be prompted. There are many different options and selecting the correct ones for your model is essential to getting optimal performance. See [here](https://github.com/acon96/home-llm/blob/develop/docs/Backend%20Configuration.md) for more information about the options on this page.\n\n**Some defaults may have been chosen for you based on the name of the selected model name or filename.** If you renamed a file or are using a fine-tuning of a supported model, then the defaults may not have been detected.",
"title": "Configure the selected model"
},
"reconfigure": {
"data": {
"max_new_tokens": "Maximum tokens to return in response",
"llm_hass_api": "Selected LLM API",
"prompt": "System Prompt",
"temperature": "Temperature",
"top_k": "Top K",
"top_p": "Top P",
"min_p": "Min P",
"typical_p": "Typical P",
"request_timeout": "Remote Request Timeout (seconds)",
"ollama_keep_alive": "(ollama) Keep Alive/Inactivity Timeout (minutes)",
"ollama_json_mode": "(ollama) JSON Output Mode",
"extra_attributes_to_expose": "Additional attribute to expose in the context",
"enable_flash_attention": "Enable Flash Attention",
"gbnf_grammar": "Enable GBNF Grammar",
"gbnf_grammar_file": "GBNF Grammar Filename",
"openai_api_key": "API Key",
"text_generation_webui_admin_key": "(text-generation-webui) Admin Key",
"service_call_regex": "Service Call Regex",
"refresh_prompt_per_turn": "Refresh System Prompt Every Turn",
"remember_conversation": "Remember conversation",
"remember_num_interactions": "Number of past interactions to remember",
"in_context_examples": "Enable in context learning (ICL) examples",
"in_context_examples_file": "In context learning examples CSV filename",
"num_in_context_examples": "Number of ICL examples to generate",
"text_generation_webui_preset": "(text-generation-webui) Generation Preset/Character Name",
"text_generation_webui_chat_mode": "(text-generation-webui) Chat Mode",
"prompt_caching": "Enable Prompt Caching",
"prompt_caching_interval": "Prompt Caching fastest refresh interval (sec)",
"context_length": "Context Length",
"batch_size": "(llama.cpp) Batch Size",
"n_threads": "(llama.cpp) Thread Count",
"n_batch_threads": "(llama.cpp) Batch Thread Count",
"thinking_prefix": "Reasoning Content Prefix",
"thinking_suffix": "Reasoning Content Suffix",
"tool_call_prefix": "Tool Call Prefix",
"tool_call_suffix": "Tool Call Suffix",
"enable_legacy_tool_calling": "Enable Legacy Tool Calling",
"max_tool_call_iterations": "Maximum Tool Call Attempts"
},
"data_description": {
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-3)'",
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
"gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.",
"prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below.",
"enable_legacy_tool_calling": "Prefer to process tool calls locally rather than relying on the backend to handle the tool calling format. Can be more reliable, however it requires properly setting the tool call prefix and suffix.",
"max_tool_call_iterations": "Set to 0 to generate the response and tool call in one attempt, without looping (use this for Home models v1-v3)."
},
"description": "Please configure the model according to how it should be prompted. There are many different options and selecting the correct ones for your model is essential to getting optimal performance. See [here](https://github.com/acon96/home-llm/blob/develop/docs/Backend%20Configuration.md) for more information about the options on this page.\n\n**Some defaults may have been chosen for you based on the name of the selected model name or filename.** If you renamed a file or are using a fine-tuning of a supported model, then the defaults may not have been detected.",
"title": "Configure the selected model"
}
}
}
},
"options": {

View File

@@ -9,11 +9,14 @@ import multiprocessing
import voluptuous as vol
import webcolors
import json
import base64
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple, cast
from webcolors import CSS3
from importlib.metadata import version
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.components import conversation
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers import intent, llm, aiohttp_client
@@ -24,12 +27,12 @@ from homeassistant.util.package import install_package, is_installed
from voluptuous_openapi import convert
from .const import (
DOMAIN,
EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
ALLOWED_SERVICE_CALL_ARGUMENTS,
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME
)
from typing import TYPE_CHECKING
@@ -296,48 +299,64 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis
return result
def get_oai_formatted_messages(conversation: Sequence[conversation.Content], user_content_as_list: bool = False, tool_args_to_str: bool = True) -> List[ChatCompletionRequestMessage]:
messages: List[ChatCompletionRequestMessage] = []
for message in conversation:
if message.role == "system":
messages.append({
"role": "system",
"content": message.content
})
elif message.role == "user":
if user_content_as_list:
messages.append({
"role": "user",
"content": [{ "type": "text", "text": message.content }]
})
else:
messages.append({
"role": "user",
"content": message.content
})
elif message.role == "assistant":
if message.tool_calls:
messages.append({
"role": "assistant",
"content": str(message.content),
"tool_calls": [
{
"type" : "function",
"id": t.id,
"function": {
"arguments": cast(str, json.dumps(t.tool_args) if tool_args_to_str else t.tool_args),
"name": t.tool_name,
}
} for t in message.tool_calls
]
})
elif message.role == "tool_result":
messages.append({
"role": "tool",
"content": json.dumps(message.tool_result),
"tool_call_id": message.tool_call_id
})
messages: List[ChatCompletionRequestMessage] = []
for message in conversation:
if message.role == "system":
messages.append({
"role": "system",
"content": message.content
})
elif message.role == "user":
images: list[str] = []
for attachment in message.attachments or ():
if not attachment.mime_type.startswith("image/"):
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="unsupported_attachment_type",
)
images.append(get_file_contents_base64(attachment.path))
return messages
if user_content_as_list:
content = [{ "type": "text", "text": message.content }]
for image in images:
content.append({ "type": "image_url", "image_url": {"url": image } })
messages.append({
"role": "user",
"content": content
})
else:
message = {
"role": "user",
"content": message.content
}
if images:
message["images"] = images
messages.append(message)
elif message.role == "assistant":
if message.tool_calls:
messages.append({
"role": "assistant",
"content": str(message.content),
"tool_calls": [
{
"type" : "function",
"id": t.id,
"function": {
"arguments": cast(str, json.dumps(t.tool_args) if tool_args_to_str else t.tool_args),
"name": t.tool_name,
}
} for t in message.tool_calls
]
})
elif message.role == "tool_result":
messages.append({
"role": "tool",
"content": json.dumps(message.tool_result),
"tool_call_id": message.tool_call_id
})
return messages
def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dict[str, Any]]:
service_dict = llm_api.api.hass.services.async_services()
@@ -377,7 +396,7 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic
return tools
def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, user_input: conversation.ConversationInput) -> tuple[llm.ToolInput | None, str | None]:
def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
if isinstance(raw_block, dict):
parsed_tool_call = raw_block
else:
@@ -407,7 +426,7 @@ def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, user_in
schema_to_validate(parsed_tool_call)
except vol.Error as ex:
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
raise MalformedToolCallException(user_input.agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
# try to fix certain arguments
args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
@@ -420,7 +439,7 @@ def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, user_in
try:
args_dict = json.loads(args_dict)
except json.JSONDecodeError:
raise MalformedToolCallException(user_input.agent_id, "", tool_name, str(args_dict), "Tool arguments were not properly formatted JSON")
raise MalformedToolCallException(agent_id, "", tool_name, str(args_dict), "Tool arguments were not properly formatted JSON")
# make sure brightness is 0-255 and not a percentage
if "brightness" in args_dict and 0.0 < args_dict["brightness"] <= 1.0:
@@ -477,4 +496,13 @@ def is_valid_hostname(host: str) -> bool:
domain_pattern = re.compile(r"^[a-z0-9]([a-z0-9\-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9\-]{0,61}[a-z0-9])?)*\.[a-z]{2,}$")
return bool(domain_pattern.match(host))
return bool(domain_pattern.match(host))
def get_file_contents_base64(file_path: Path) -> str:
"""Reads a file and returns its contents encoded in base64."""
with open(file_path, "rb") as f:
encoded_bytes = base64.b64encode(f.read())
encoded_str = encoded_bytes.decode('utf-8')
return encoded_str
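For reference, a small illustration of how this helper feeds the OpenAI-style image content built in get_oai_formatted_messages above; the file path is hypothetical.

# Hypothetical usage: the raw base64 string is placed directly into an
# OpenAI-style "image_url" content part, mirroring the message formatting above.
snapshot = get_file_contents_base64(Path("/config/www/snapshot.jpg"))
image_part = {"type": "image_url", "image_url": {"url": snapshot}}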

View File

@@ -1,5 +1,5 @@
# types from Home Assistant
homeassistant>=2024.8.3
homeassistant>=2025.8.3
hassil
home-assistant-intents