Files
home-llm/custom_components/llama_conversation/ai_task.py
2025-12-14 02:30:58 -05:00

270 lines
9.9 KiB
Python

"""AI Task integration for Local LLMs."""
from __future__ import annotations
from json import JSONDecodeError
import logging
from enum import StrEnum
from typing import Any, cast
import voluptuous as vol
from voluptuous_openapi import convert as convert_to_openapi
from homeassistant.helpers import llm
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from .entity import LocalLLMEntity, LocalLLMClient
from .const import (
CONF_RESPONSE_JSON_SCHEMA,
CONF_AI_TASK_PROMPT,
DEFAULT_AI_TASK_PROMPT,
CONF_AI_TASK_RETRIES,
DEFAULT_AI_TASK_RETRIES,
CONF_AI_TASK_EXTRACTION_METHOD,
DEFAULT_AI_TASK_EXTRACTION_METHOD,
DOMAIN,
)
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry[LocalLLMClient],
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Register one AI Task entity for each ``ai_task_data`` subentry.

    Subentries of any other type are ignored. Every entity is added in its
    own ``async_add_entities`` call so that it is tied to the id of the
    subentry it was created from.
    """
    task_subentries = (
        candidate
        for candidate in config_entry.subentries.values()
        if candidate.subentry_type == "ai_task_data"
    )

    for task_subentry in task_subentries:
        entity = LocalLLMTaskEntity(
            hass,
            config_entry,
            task_subentry,
            config_entry.runtime_data,
        )
        async_add_entities(
            [entity],
            config_subentry_id=task_subentry.subentry_id,
        )
class ResultExtractionMethod(StrEnum):
NONE = "none"
STRUCTURED_OUTPUT = "structure"
TOOL = "tool"
class SubmitResponseTool(llm.Tool):
    """Tool through which the model submits the structured AI task payload."""

    name = "submit_response"
    description = "Submit the structured response payload for the AI task"

    def __init__(self, parameters_schema: vol.Schema):
        """Store the schema describing the arguments this tool accepts."""
        self.parameters = parameters_schema

    async def async_call(
        self,
        hass: HomeAssistant,
        tool_input: llm.ToolInput,
        llm_context: llm.LLMContext,
    ) -> dict:
        """Return the submitted arguments, or an empty dict when there are none."""
        submitted_args = tool_input.tool_args
        if submitted_args:
            return submitted_args
        return {}
class SubmitResponseAPI(llm.API):
    """LLM API that exposes a fixed list of tools for submitting AI task results."""

    def __init__(self, hass: HomeAssistant, tools: list[llm.Tool]) -> None:
        """Keep the tools to expose and pass the API id and name to the base class."""
        self._tools = tools
        api_id = f"{DOMAIN}-ai-task-tool"
        super().__init__(hass=hass, id=api_id, name="AI Task Tool API")

    async def async_get_api_instance(
        self, llm_context: llm.LLMContext
    ) -> llm.APIInstance:
        """Build an API instance holding the stored tools and the submit prompt."""
        submit_prompt = "Call submit_response to return the structured AI task result."
        return llm.APIInstance(
            api=self,
            api_prompt=submit_prompt,
            llm_context=llm_context,
            tools=self._tools,
        )
class LocalLLMTaskEntity(
    ai_task.AITaskEntity,
    LocalLLMEntity,
):
    """AI Task entity that generates data through the local LLM client."""

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the entity and set its supported features.

        Attachment support is only advertised when the client reports vision
        support for the current runtime options.
        """
        super().__init__(*args, **kwargs)

        if self.client._supports_vision(self.runtime_options):
            self._attr_supported_features = (
                ai_task.AITaskEntityFeature.GENERATE_DATA
                | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
            )
        else:
            self._attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA

    async def _generate_once(
        self,
        message_history: list[conversation.Content],
        chat_log: conversation.ChatLog,
        entity_options: dict[str, Any],
    ) -> tuple[str, list | None]:
        """Run a single generation pass against the LLM client.

        Uses the client's ``_generate_stream`` when it exists and falls back
        to the blocking ``_generate`` otherwise.

        Returns:
            A tuple of the concatenated, whitespace-stripped response text and
            the tool calls reported by the client (``None`` when there were
            none). When streaming, the tool calls of the last chunk that
            carried any are the ones returned.
        """
        collected: list[str] = []
        collected_tools = None

        # call the LLM client directly (not _async_generate) since that will attempt to execute tool calls
        if hasattr(self.client, "_generate_stream"):
            async for chunk in self.client._generate_stream(
                message_history,
                chat_log.llm_api,
                self.entity_id,
                entity_options,
            ):
                if chunk.response:
                    collected.append(chunk.response)
                if chunk.tool_calls:
                    collected_tools = chunk.tool_calls
        else:
            blocking_result = await self.client._generate(
                message_history,
                chat_log.llm_api,
                self.entity_id,
                entity_options,
            )
            if blocking_result.response:
                collected.append(blocking_result.response)
            if blocking_result.tool_calls:
                collected_tools = blocking_result.tool_calls

        text = "".join(collected).strip()
        return text, collected_tools

    def _extract_data(
        self,
        raw_text: str,
        tool_calls: list | None,
        extraction_method: ResultExtractionMethod,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenDataTaskResult:
        """Extract the final data from the LLM response based on the extraction method.

        * ``NONE`` returns ``raw_text`` unchanged.
        * ``STRUCTURED_OUTPUT`` parses ``raw_text`` as JSON.
        * ``TOOL`` returns the ``tool_args`` of the first tool call.

        Raises:
            HomeAssistantError: if the text is not valid JSON, if the first
                tool call is missing or has no arguments, or if the
                extraction method is not one of the above.
        """
        if extraction_method == ResultExtractionMethod.NONE:
            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=raw_text,
            )

        if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
            try:
                data = json_loads(raw_text)
            except JSONDecodeError as err:
                raise HomeAssistantError(
                    "Error with Local LLM structured response"
                ) from err
            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=data,
            )

        if extraction_method == ResultExtractionMethod.TOOL:
            # Only the first tool call is considered; an empty list counts as "no call".
            first_tool = (tool_calls or [None])[0]
            if not first_tool or not getattr(first_tool, "tool_args", None):
                raise HomeAssistantError("Error with Local LLM tool response")
            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=first_tool.tool_args,
            )

        raise HomeAssistantError("Invalid extraction method for AI Task")

    async def _async_generate_data(
        self,
        task: ai_task.GenDataTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenDataTaskResult:
        """Handle a generate data task.

        Builds the message history (the configured task prompt always becomes
        the system message, and the task instructions are appended when the
        chat log holds no user message), prepares the configured extraction
        method, then generates and extracts a result. A ``HomeAssistantError``
        raised during generation or extraction is retried up to the
        configured number of retries.

        Raises:
            HomeAssistantError: with its original message when the task setup
                is invalid or the last attempt fails with a
                ``HomeAssistantError``; with a generic "Unhandled error"
                message wrapping any other exception.
        """
        try:
            task_prompt = self.runtime_options.get(CONF_AI_TASK_PROMPT, DEFAULT_AI_TASK_PROMPT)
            # int() keeps range() below working if the option was stored as a float.
            retries = max(0, int(self.runtime_options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES)))
            extraction_method = self.runtime_options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)
            max_attempts = retries + 1

            # Work on a copy so per-task additions never leak into the entity's options.
            entity_options = {**self.runtime_options}

            if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
                if not task.structure:
                    raise HomeAssistantError(
                        "Structured extraction selected but no task structure was provided"
                    )
                entity_options[CONF_RESPONSE_JSON_SCHEMA] = convert_to_openapi(task.structure)

            if extraction_method == ResultExtractionMethod.TOOL:
                if not task.structure:
                    raise HomeAssistantError(
                        "Tool extraction selected but no task structure was provided"
                    )
                # NOTE(review): any structure that is not a plain dict (for example an
                # already-built vol.Schema) falls back to this permissive empty schema,
                # so the tool then advertises no fields -- confirm this is intended.
                parameters_schema = vol.Schema({}, extra=vol.ALLOW_EXTRA)
                if isinstance(task.structure, dict):
                    parameters_schema = vol.Schema(task.structure)

                # Replaces the chat log's LLM API so the client only sees submit_response.
                chat_log.llm_api = await SubmitResponseAPI(self.hass, [SubmitResponseTool(parameters_schema)]).async_get_api_instance(
                    llm.LLMContext(DOMAIN, context=None, language=None, assistant=None, device_id=None)
                )

            message_history = list(chat_log.content) if chat_log.content else []
            system_message = conversation.SystemContent(content=task_prompt)
            if message_history and isinstance(message_history[0], conversation.SystemContent):
                message_history[0] = system_message
            else:
                message_history.insert(0, system_message)

            if not any(isinstance(msg, conversation.UserContent) for msg in message_history):
                message_history.append(
                    conversation.UserContent(
                        content=task.instructions, attachments=task.attachments
                    )
                )

            for attempt in range(1, max_attempts + 1):
                _LOGGER.debug(
                    "Generating response for %s (attempt %s/%s)...",
                    task.name,
                    attempt,
                    max_attempts,
                )
                try:
                    text, tool_calls = await self._generate_once(message_history, chat_log, entity_options)
                    return self._extract_data(text, tool_calls, extraction_method, chat_log)
                except HomeAssistantError as err:
                    if attempt >= max_attempts:
                        # Out of retries: surface the real failure to the caller.
                        raise
                    _LOGGER.debug(
                        "AI Task '%s' attempt %s/%s failed, retrying: %s",
                        task.name,
                        attempt,
                        max_attempts,
                        err,
                    )

            # max_attempts is always >= 1, so the loop returns or raises; this
            # is a defensive fallback only.
            raise HomeAssistantError("AI Task generation failed without an error")
        except HomeAssistantError:
            # Deliberate errors already carry a meaningful message; do not
            # re-wrap them as "unhandled" or log them a second time.
            raise
        except Exception as err:
            _LOGGER.exception("Unhandled exception while running AI Task '%s'", task.name)
            raise HomeAssistantError(
                f"Unhandled error while running AI Task '{task.name}'"
            ) from err