working ai task entities

This commit is contained in:
Alex O'Connell
2025-12-14 10:34:21 -05:00
parent b547da286f
commit 1f078d0a41
8 changed files with 39 additions and 57 deletions

View File

@@ -57,7 +57,7 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION, ) # Platform.AI_TASK)
PLATFORMS = (Platform.CONVERSATION, Platform.AI_TASK)
BACKEND_TO_CLS: dict[str, type[LocalLLMClient]] = {
BACKEND_TYPE_LLAMA_CPP: LlamaCppClient,

View File

@@ -18,7 +18,6 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from .entity import LocalLLMEntity, LocalLLMClient
from .const import (
CONF_RESPONSE_JSON_SCHEMA,
@@ -41,13 +40,17 @@ async def async_setup_entry(
) -> None:
"""Set up AI Task entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "ai_task_data":
if subentry.subentry_type != ai_task.DOMAIN:
continue
async_add_entities(
[LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)],
config_subentry_id=subentry.subentry_id,
)
# create one entity per subentry
ai_task_entity = LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)
# make sure model is loaded
await config_entry.runtime_data._async_load_model(dict(subentry.data))
# register the ai task entity
async_add_entities([ai_task_entity], config_subentry_id=subentry.subentry_id)
class ResultExtractionMethod(StrEnum):
@@ -88,6 +91,7 @@ class SubmitResponseAPI(llm.API):
api_prompt="Call submit_response to return the structured AI task result.",
llm_context=llm_context,
tools=self._tools,
custom_serializer=llm.selector_serializer,
)
@@ -152,9 +156,11 @@ class LocalLLMTaskEntity(
tool_calls: list | None,
extraction_method: ResultExtractionMethod,
chat_log: conversation.ChatLog,
structure: vol.Schema | None,
) -> ai_task.GenDataTaskResult:
"""Extract the final data from the LLM response based on the extraction method."""
if extraction_method == ResultExtractionMethod.NONE:
if extraction_method == ResultExtractionMethod.NONE or structure is None:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=raw_text,
@@ -176,6 +182,7 @@ class LocalLLMTaskEntity(
first_tool = (tool_calls or [None])[0]
if not first_tool or not getattr(first_tool, "tool_args", None):
raise HomeAssistantError("Error with Local LLM tool response")
structure(first_tool.tool_args) # validate against structure
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=first_tool.tool_args,
@@ -197,7 +204,8 @@ class LocalLLMTaskEntity(
entity_options = {**self.runtime_options}
if task.structure and extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
entity_options[CONF_RESPONSE_JSON_SCHEMA] = convert_to_openapi(task.structure)
_LOGGER.debug("Using structure for AI Task '%s': %s", task.name, task.structure)
entity_options[CONF_RESPONSE_JSON_SCHEMA] = convert_to_openapi(task.structure, custom_serializer=llm.selector_serializer)
message_history = list(chat_log.content) if chat_log.content else []
@@ -214,21 +222,8 @@ class LocalLLMTaskEntity(
)
)
if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT and not task.structure:
raise HomeAssistantError(
"Structured extraction selected but no task structure was provided"
)
if extraction_method == ResultExtractionMethod.TOOL:
if not task.structure:
raise HomeAssistantError(
"Tool extraction selected but no task structure was provided"
)
parameters_schema = vol.Schema({}, extra=vol.ALLOW_EXTRA)
if isinstance(task.structure, dict):
parameters_schema = vol.Schema(task.structure)
chat_log.llm_api = await SubmitResponseAPI(self.hass, [SubmitResponseTool(parameters_schema)]).async_get_api_instance(
if extraction_method == ResultExtractionMethod.TOOL and task.structure:
chat_log.llm_api = await SubmitResponseAPI(self.hass, [SubmitResponseTool(task.structure)]).async_get_api_instance(
llm.LLMContext(DOMAIN, context=None, language=None, assistant=None, device_id=None)
)
@@ -242,7 +237,7 @@ class LocalLLMTaskEntity(
max_attempts,
)
text, tool_calls = await self._generate_once(message_history, chat_log, entity_options)
return self._extract_data(text, tool_calls, extraction_method, chat_log)
return self._extract_data(text, tool_calls, extraction_method, chat_log, task.structure)
except HomeAssistantError as err:
last_error = err
if attempt < max_attempts - 1:

View File

@@ -9,7 +9,6 @@ import time
from typing import Any, Callable, List, Generator, AsyncGenerator, Optional, cast
from homeassistant.components import conversation as conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
@@ -303,7 +302,7 @@ class LlamaCppClient(LocalLLMClient):
entity_ids = [
state.entity_id for state in self.hass.states.async_all() \
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id)
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
]
_LOGGER.debug(f"watching entities: {entity_ids}")
@@ -440,7 +439,7 @@ class LlamaCppClient(LocalLLMClient):
_LOGGER.debug(f"Options: {entity_options}")
messages = get_oai_formatted_messages(conversation, user_content_as_list=True)
messages = get_oai_formatted_messages(conversation)
tools = None
if llm_api:
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())

View File

@@ -10,6 +10,7 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime
from homeassistant.components import conversation, ai_task
from homeassistant.data_entry_flow import (
AbortFlow,
)
@@ -405,8 +406,8 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration."""
return {
"conversation": LocalLLMSubentryFlowHandler,
"ai_task_data": LocalLLMSubentryFlowHandler,
conversation.DOMAIN: LocalLLMSubentryFlowHandler,
ai_task.DOMAIN: LocalLLMSubentryFlowHandler,
}
@@ -590,7 +591,7 @@ def local_llama_config_option_schema(
subentry_type: str,
) -> dict:
is_ai_task = subentry_type == "ai_task_data"
is_ai_task = subentry_type == ai_task.DOMAIN
default_prompt = DEFAULT_AI_TASK_PROMPT if is_ai_task else build_prompt_template(language, DEFAULT_PROMPT)
prompt_key = CONF_AI_TASK_PROMPT if is_ai_task else CONF_PROMPT
prompt_selector = TextSelector(TextSelectorConfig(type=TextSelectorType.TEXT, multiline=True)) if is_ai_task else TemplateSelector()
@@ -914,7 +915,7 @@ def local_llama_config_option_schema(
): NumberSelector(NumberSelectorConfig(min=-1, max=1440, step=1, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)),
})
if subentry_type == "conversation":
if subentry_type == conversation.DOMAIN:
apis: list[SelectOptionDict] = [
SelectOptionDict(
label=api.name,
@@ -954,8 +955,8 @@ def local_llama_config_option_schema(
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
): int,
})
elif subentry_type == "ai_task_data":
# no extra conversation/tool options for ai_task_data beyond schema defaults
elif subentry_type == ai_task.DOMAIN:
# no extra conversation/tool options for ai_task
pass
# sort the options
@@ -1155,7 +1156,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
description_placeholders = {}
entry = self._get_entry()
backend_type = entry.data[CONF_BACKEND_TYPE]
is_ai_task = self._subentry_type == "ai_task_data"
is_ai_task = self._subentry_type == ai_task.DOMAIN
if is_ai_task:
if CONF_AI_TASK_PROMPT not in self.model_config:

View File

@@ -10,7 +10,6 @@ from typing import Literal, Any, List, Dict, Optional, Sequence, Tuple, AsyncIte
from dataclasses import dataclass
from homeassistant.components import conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import MATCH_ALL, CONF_LLM_HASS_API
@@ -352,7 +351,7 @@ class LocalLLMClient:
"""Gather all exposed domains"""
domains = set()
for state in self.hass.states.async_all():
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id):
domains.add(state.domain)
return list(domains)
@@ -365,7 +364,7 @@ class LocalLLMClient:
area_registry = ar.async_get(self.hass)
for state in self.hass.states.async_all():
if not async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
if not async_should_expose(self.hass, conversation.DOMAIN, state.entity_id):
continue
entity = entity_registry.async_get(state.entity_id)

View File

@@ -4,7 +4,7 @@
"version": "0.4.4",
"codeowners": ["@acon96"],
"config_flow": true,
"dependencies": ["conversation"],
"dependencies": ["conversation", "ai_task"],
"after_dependencies": ["assist_pipeline", "intent"],
"documentation": "https://github.com/acon96/home-llm",
"integration_type": "service",

View File

@@ -177,7 +177,7 @@
}
}
},
"ai_task_data": {
"ai_task": {
"initiate_flow": {
"user": "Add AI Task Handler",
"reconfigure": "Reconfigure AI Task Handler"

View File

@@ -3,6 +3,7 @@
from typing import Any, cast
import pytest
import voluptuous as vol
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
@@ -82,7 +83,7 @@ async def test_structured_output_success(hass):
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure={"foo": int})
task = DummyGenTask(structure=vol.Schema({"foo": int}))
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
@@ -97,7 +98,7 @@ async def test_structured_output_invalid_json_raises(hass):
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure={"foo": int})
task = DummyGenTask(structure=vol.Schema({"foo": int}))
with pytest.raises(HomeAssistantError):
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
@@ -112,7 +113,7 @@ async def test_tool_extraction_success(hass):
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure={"value": int})
task = DummyGenTask(structure=vol.Schema({"value": int}))
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
@@ -133,21 +134,8 @@ async def test_tool_extraction_missing_tool_args_raises(hass):
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure={"value": int})
task = DummyGenTask(structure=vol.Schema({"value": int}))
with pytest.raises(HomeAssistantError):
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
@pytest.mark.asyncio
async def test_tool_extraction_requires_structure(hass):
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="")),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure=None)
with pytest.raises(HomeAssistantError):
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))