Mirror of https://github.com/acon96/home-llm.git

Commit: working ai task entities

@@ -57,7 +57,7 @@ _LOGGER = logging.getLogger(__name__)
 
 CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
 
-PLATFORMS = (Platform.CONVERSATION, ) # Platform.AI_TASK)
+PLATFORMS = (Platform.CONVERSATION, Platform.AI_TASK)
 
 BACKEND_TO_CLS: dict[str, type[LocalLLMClient]] = {
     BACKEND_TYPE_LLAMA_CPP: LlamaCppClient,
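
Context for the PLATFORMS change: with Platform.AI_TASK enabled, Home Assistant's standard setup hook forwards the config entry to both platforms. A minimal sketch of that pattern (this integration's actual async_setup_entry is not part of the diff):

    from homeassistant.config_entries import ConfigEntry
    from homeassistant.core import HomeAssistant

    async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
        # Forwarding runs each platform module's own async_setup_entry,
        # so the new ai_task platform is loaded alongside conversation.
        await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
        return True
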
@@ -18,7 +18,6 @@ from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
 from homeassistant.util.json import json_loads
 
-
 from .entity import LocalLLMEntity, LocalLLMClient
 from .const import (
     CONF_RESPONSE_JSON_SCHEMA,
@@ -41,13 +40,17 @@ async def async_setup_entry(
 ) -> None:
     """Set up AI Task entities."""
     for subentry in config_entry.subentries.values():
-        if subentry.subentry_type != "ai_task_data":
+        if subentry.subentry_type != ai_task.DOMAIN:
             continue
 
-        async_add_entities(
-            [LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)],
-            config_subentry_id=subentry.subentry_id,
-        )
+        # create one entity per subentry
+        ai_task_entity = LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)
+
+        # make sure model is loaded
+        await config_entry.runtime_data._async_load_model(dict(subentry.data))
+
+        # register the ai task entity
+        async_add_entities([ai_task_entity], config_subentry_id=subentry.subentry_id)
 
 
 class ResultExtractionMethod(StrEnum):
@@ -88,6 +91,7 @@ class SubmitResponseAPI(llm.API):
             api_prompt="Call submit_response to return the structured AI task result.",
             llm_context=llm_context,
             tools=self._tools,
+            custom_serializer=llm.selector_serializer,
         )
 
 
@@ -152,9 +156,11 @@ class LocalLLMTaskEntity(
         tool_calls: list | None,
         extraction_method: ResultExtractionMethod,
         chat_log: conversation.ChatLog,
+        structure: vol.Schema | None,
     ) -> ai_task.GenDataTaskResult:
         """Extract the final data from the LLM response based on the extraction method."""
-        if extraction_method == ResultExtractionMethod.NONE:
+
+        if extraction_method == ResultExtractionMethod.NONE or structure is None:
             return ai_task.GenDataTaskResult(
                 conversation_id=chat_log.conversation_id,
                 data=raw_text,
@@ -176,6 +182,7 @@ class LocalLLMTaskEntity(
         first_tool = (tool_calls or [None])[0]
         if not first_tool or not getattr(first_tool, "tool_args", None):
             raise HomeAssistantError("Error with Local LLM tool response")
+        structure(first_tool.tool_args)  # validate against structure
         return ai_task.GenDataTaskResult(
             conversation_id=chat_log.conversation_id,
             data=first_tool.tool_args,
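
The added structure(first_tool.tool_args) line leans on voluptuous schemas being callable validators: calling a schema returns the validated data or raises vol.Invalid. A standalone illustration (plain voluptuous, not this integration's code):

    import voluptuous as vol

    schema = vol.Schema({"value": int})
    schema({"value": 3})       # passes; returns the validated dict
    schema({"value": "oops"})  # raises vol.Invalid
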
@@ -197,7 +204,8 @@ class LocalLLMTaskEntity(
 
         entity_options = {**self.runtime_options}
         if task.structure and extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
-            entity_options[CONF_RESPONSE_JSON_SCHEMA] = convert_to_openapi(task.structure)
+            _LOGGER.debug("Using structure for AI Task '%s': %s", task.name, task.structure)
+            entity_options[CONF_RESPONSE_JSON_SCHEMA] = convert_to_openapi(task.structure, custom_serializer=llm.selector_serializer)
 
         message_history = list(chat_log.content) if chat_log.content else []
 
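
The custom_serializer argument matters once a task structure contains Home Assistant selectors, which a plain voluptuous-to-JSON-schema conversion cannot describe. A sketch of the effect, assuming convert_to_openapi delegates to voluptuous_openapi.convert the way Home Assistant's own LLM helpers do:

    import voluptuous as vol
    from voluptuous_openapi import convert
    from homeassistant.helpers import llm, selector

    schema = vol.Schema({
        vol.Required("brightness"): selector.NumberSelector(
            selector.NumberSelectorConfig(min=0, max=255)
        ),
    })
    # llm.selector_serializer renders the selector as a JSON-schema fragment,
    # roughly {"type": "number", "minimum": 0, "maximum": 255}.
    spec = convert(schema, custom_serializer=llm.selector_serializer)
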
@@ -214,21 +222,8 @@ class LocalLLMTaskEntity(
                 )
             )
 
-        if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT and not task.structure:
-            raise HomeAssistantError(
-                "Structured extraction selected but no task structure was provided"
-            )
-        if extraction_method == ResultExtractionMethod.TOOL:
-            if not task.structure:
-                raise HomeAssistantError(
-                    "Tool extraction selected but no task structure was provided"
-                )
-
-            parameters_schema = vol.Schema({}, extra=vol.ALLOW_EXTRA)
-            if isinstance(task.structure, dict):
-                parameters_schema = vol.Schema(task.structure)
-
-            chat_log.llm_api = await SubmitResponseAPI(self.hass, [SubmitResponseTool(parameters_schema)]).async_get_api_instance(
+        if extraction_method == ResultExtractionMethod.TOOL and task.structure:
+            chat_log.llm_api = await SubmitResponseAPI(self.hass, [SubmitResponseTool(task.structure)]).async_get_api_instance(
                 llm.LLMContext(DOMAIN, context=None, language=None, assistant=None, device_id=None)
             )
 
@@ -242,7 +237,7 @@ class LocalLLMTaskEntity(
                     max_attempts,
                 )
                 text, tool_calls = await self._generate_once(message_history, chat_log, entity_options)
-                return self._extract_data(text, tool_calls, extraction_method, chat_log)
+                return self._extract_data(text, tool_calls, extraction_method, chat_log, task.structure)
             except HomeAssistantError as err:
                 last_error = err
                 if attempt < max_attempts - 1:
@@ -9,7 +9,6 @@ import time
 from typing import Any, Callable, List, Generator, AsyncGenerator, Optional, cast
 
 from homeassistant.components import conversation as conversation
-from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
 from homeassistant.components.homeassistant.exposed_entities import async_should_expose
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_LLM_HASS_API
@@ -303,7 +302,7 @@ class LlamaCppClient(LocalLLMClient):
 
         entity_ids = [
             state.entity_id for state in self.hass.states.async_all() \
-                if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id)
+                if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
         ]
 
         _LOGGER.debug(f"watching entities: {entity_ids}")
@@ -440,7 +439,7 @@ class LlamaCppClient(LocalLLMClient):
 
         _LOGGER.debug(f"Options: {entity_options}")
 
-        messages = get_oai_formatted_messages(conversation, user_content_as_list=True)
+        messages = get_oai_formatted_messages(conversation)
         tools = None
         if llm_api:
             tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
@@ -10,6 +10,7 @@ import voluptuous as vol
 
 from homeassistant.core import HomeAssistant
 from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime
+from homeassistant.components import conversation, ai_task
 from homeassistant.data_entry_flow import (
     AbortFlow,
 )
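
This import underpins the string-literal cleanup in the rest of the commit: each component module publishes its integration domain. For reference, the values as defined in Home Assistant core (which is also why the subentry type and its strings.json section are renamed from "ai_task_data" to "ai_task"):

    from homeassistant.components import ai_task, conversation

    assert conversation.DOMAIN == "conversation"
    assert ai_task.DOMAIN == "ai_task"
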
@@ -405,8 +406,8 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
     ) -> dict[str, type[ConfigSubentryFlow]]:
         """Return subentries supported by this integration."""
         return {
-            "conversation": LocalLLMSubentryFlowHandler,
-            "ai_task_data": LocalLLMSubentryFlowHandler,
+            conversation.DOMAIN: LocalLLMSubentryFlowHandler,
+            ai_task.DOMAIN: LocalLLMSubentryFlowHandler,
         }
 
 
@@ -590,7 +591,7 @@ def local_llama_config_option_schema(
     subentry_type: str,
 ) -> dict:
 
-    is_ai_task = subentry_type == "ai_task_data"
+    is_ai_task = subentry_type == ai_task.DOMAIN
     default_prompt = DEFAULT_AI_TASK_PROMPT if is_ai_task else build_prompt_template(language, DEFAULT_PROMPT)
     prompt_key = CONF_AI_TASK_PROMPT if is_ai_task else CONF_PROMPT
     prompt_selector = TextSelector(TextSelectorConfig(type=TextSelectorType.TEXT, multiline=True)) if is_ai_task else TemplateSelector()
@@ -914,7 +915,7 @@ def local_llama_config_option_schema(
         ): NumberSelector(NumberSelectorConfig(min=-1, max=1440, step=1, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)),
     })
 
-    if subentry_type == "conversation":
+    if subentry_type == conversation.DOMAIN:
         apis: list[SelectOptionDict] = [
             SelectOptionDict(
                 label=api.name,
@@ -954,8 +955,8 @@ def local_llama_config_option_schema(
                 default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
             ): int,
         })
-    elif subentry_type == "ai_task_data":
-        # no extra conversation/tool options for ai_task_data beyond schema defaults
+    elif subentry_type == ai_task.DOMAIN:
+        # no extra conversation/tool options for ai_task
        pass
 
     # sort the options
@@ -1155,7 +1156,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
        description_placeholders = {}
        entry = self._get_entry()
        backend_type = entry.data[CONF_BACKEND_TYPE]
-       is_ai_task = self._subentry_type == "ai_task_data"
+       is_ai_task = self._subentry_type == ai_task.DOMAIN
 
        if is_ai_task:
            if CONF_AI_TASK_PROMPT not in self.model_config:
@@ -10,7 +10,6 @@ from typing import Literal, Any, List, Dict, Optional, Sequence, Tuple, AsyncIte
 from dataclasses import dataclass
 
 from homeassistant.components import conversation
-from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
 from homeassistant.components.homeassistant.exposed_entities import async_should_expose
 from homeassistant.config_entries import ConfigEntry, ConfigSubentry
 from homeassistant.const import MATCH_ALL, CONF_LLM_HASS_API
@@ -352,7 +351,7 @@ class LocalLLMClient:
         """Gather all exposed domains"""
         domains = set()
         for state in self.hass.states.async_all():
-            if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
+            if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id):
                 domains.add(state.domain)
 
         return list(domains)
@@ -365,7 +364,7 @@ class LocalLLMClient:
         area_registry = ar.async_get(self.hass)
 
         for state in self.hass.states.async_all():
-            if not async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
+            if not async_should_expose(self.hass, conversation.DOMAIN, state.entity_id):
                 continue
 
             entity = entity_registry.async_get(state.entity_id)
@@ -4,7 +4,7 @@
     "version": "0.4.4",
     "codeowners": ["@acon96"],
     "config_flow": true,
-    "dependencies": ["conversation"],
+    "dependencies": ["conversation", "ai_task"],
     "after_dependencies": ["assist_pipeline", "intent"],
     "documentation": "https://github.com/acon96/home-llm",
     "integration_type": "service",
@@ -177,7 +177,7 @@
         }
       }
     },
-    "ai_task_data": {
+    "ai_task": {
       "initiate_flow": {
         "user": "Add AI Task Handler",
         "reconfigure": "Reconfigure AI Task Handler"
@@ -3,6 +3,7 @@
 from typing import Any, cast
 
 import pytest
+import voluptuous as vol
 
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import llm
@@ -82,7 +83,7 @@ async def test_structured_output_success(hass):
         {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
     )
     chat_log = DummyChatLog()
-    task = DummyGenTask(structure={"foo": int})
+    task = DummyGenTask(structure=vol.Schema({"foo": int}))
 
     result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
 
@@ -97,7 +98,7 @@ async def test_structured_output_invalid_json_raises(hass):
         {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
     )
     chat_log = DummyChatLog()
-    task = DummyGenTask(structure={"foo": int})
+    task = DummyGenTask(structure=vol.Schema({"foo": int}))
 
     with pytest.raises(HomeAssistantError):
         await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
@@ -112,7 +113,7 @@ async def test_tool_extraction_success(hass):
         {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
     )
     chat_log = DummyChatLog()
-    task = DummyGenTask(structure={"value": int})
+    task = DummyGenTask(structure=vol.Schema({"value": int}))
 
     result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
 
@@ -133,21 +134,8 @@ async def test_tool_extraction_missing_tool_args_raises(hass):
         {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
     )
     chat_log = DummyChatLog()
-    task = DummyGenTask(structure={"value": int})
+    task = DummyGenTask(structure=vol.Schema({"value": int}))
 
     with pytest.raises(HomeAssistantError):
         await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
 
-
-@pytest.mark.asyncio
-async def test_tool_extraction_requires_structure(hass):
-    entity = DummyTaskEntity(
-        hass,
-        DummyClient(TextGenerationResult(response="")),
-        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
-    )
-    chat_log = DummyChatLog()
-    task = DummyGenTask(structure=None)
-
-    with pytest.raises(HomeAssistantError):
-        await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
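
The dummy tasks now wrap their structures in vol.Schema because the updated _extract_data calls the structure directly as a validator, and a plain dict is not callable:

    import voluptuous as vol

    structure = vol.Schema({"value": int})
    structure({"value": 5})           # ok: schemas validate when called
    # {"value": int}({"value": 5})    # TypeError: 'dict' object is not callable

The deleted test_tool_extraction_requires_structure matches the earlier behavior change: tool extraction without a structure no longer raises; the generate path simply skips registering the submit-response tool and falls back to returning raw text.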