support structured output for AI tasks

This commit is contained in:
Alex O'Connell
2025-12-14 01:31:39 -05:00
parent 6010bdf26c
commit b547da286f
9 changed files with 534 additions and 141 deletions

.gitignore
View File

@@ -11,4 +11,5 @@ main.log
*.xlsx
notes.txt
runpod_bootstrap.sh
*.code-workspace
*.code-workspace
.coverage

View File

@@ -5,18 +5,30 @@ from __future__ import annotations
from json import JSONDecodeError
import logging
from enum import StrEnum
from typing import Any, cast
import voluptuous as vol
from voluptuous_openapi import convert as convert_to_openapi
from homeassistant.helpers import llm
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, Context
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from .entity import LocalLLMEntity, LocalLLMClient
from .const import (
CONF_PROMPT,
DEFAULT_PROMPT,
CONF_RESPONSE_JSON_SCHEMA,
CONF_AI_TASK_PROMPT,
DEFAULT_AI_TASK_PROMPT,
CONF_AI_TASK_RETRIES,
DEFAULT_AI_TASK_RETRIES,
CONF_AI_TASK_EXTRACTION_METHOD,
DEFAULT_AI_TASK_EXTRACTION_METHOD,
DOMAIN,
)
_LOGGER = logging.getLogger(__name__)
@@ -37,90 +49,221 @@ async def async_setup_entry(
config_subentry_id=subentry.subentry_id,
)
class ResultExtractionMethod(StrEnum):
NONE = "none"
STRUCTURED_OUTPUT = "structure"
TOOL = "tool"
class SubmitResponseTool(llm.Tool):
name = "submit_response"
description = "Submit the structured response payload for the AI task"
def __init__(self, parameters_schema: vol.Schema):
self.parameters = parameters_schema
async def async_call(
self,
hass: HomeAssistant,
tool_input: llm.ToolInput,
llm_context: llm.LLMContext,
) -> dict:
return tool_input.tool_args or {}
class SubmitResponseAPI(llm.API):
def __init__(self, hass: HomeAssistant, tools: list[llm.Tool]) -> None:
self._tools = tools
super().__init__(
hass=hass,
id=f"{DOMAIN}-ai-task-tool",
name="AI Task Tool API",
)
async def async_get_api_instance(
self, llm_context: llm.LLMContext
) -> llm.APIInstance:
return llm.APIInstance(
api=self,
api_prompt="Call submit_response to return the structured AI task result.",
llm_context=llm_context,
tools=self._tools,
)
class LocalLLMTaskEntity(
ai_task.AITaskEntity,
LocalLLMEntity,
):
"""Ollama AI Task entity."""
"""AI Task entity."""
def __init__(self, *args, **kwargs) -> None:
"""Initialize Ollama AI Task entity."""
"""Initialize AI Task entity."""
super().__init__(*args, **kwargs)
if self.client._supports_vision(self.runtime_options):
self._attr_supported_features = (
ai_task.AITaskEntityFeature.GENERATE_DATA |
ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
ai_task.AITaskEntityFeature.GENERATE_DATA
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
)
else:
self._attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
async def _generate_once(
self,
message_history: list[conversation.Content],
chat_log: conversation.ChatLog,
entity_options: dict[str, Any],
) -> tuple[str, list | None]:
"""Generate a single response from the LLM."""
collected: list[str] = []
collected_tools = None
# call the LLM client directly (not _async_generate), since _async_generate would try to execute any tool calls itself
if hasattr(self.client, "_generate_stream"):
async for chunk in self.client._generate_stream(
message_history,
chat_log.llm_api,
self.entity_id,
entity_options,
):
if chunk.response:
collected.append(chunk.response)
if chunk.tool_calls:
collected_tools = chunk.tool_calls
else:
blocking_result = await self.client._generate(
message_history,
chat_log.llm_api,
self.entity_id,
entity_options,
)
if blocking_result.response:
collected.append(blocking_result.response)
if blocking_result.tool_calls:
collected_tools = blocking_result.tool_calls
text = "".join(collected).strip()
return text, collected_tools
def _extract_data(
self,
raw_text: str,
tool_calls: list | None,
extraction_method: ResultExtractionMethod,
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Extract the final data from the LLM response based on the extraction method."""
if extraction_method == ResultExtractionMethod.NONE:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=raw_text,
)
if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
try:
data = json_loads(raw_text)
except JSONDecodeError as err:
raise HomeAssistantError(
"Error with Local LLM structured response"
) from err
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=data,
)
if extraction_method == ResultExtractionMethod.TOOL:
first_tool = (tool_calls or [None])[0]
if not first_tool or not getattr(first_tool, "tool_args", None):
raise HomeAssistantError("Error with Local LLM tool response")
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=first_tool.tool_args,
)
raise HomeAssistantError("Invalid extraction method for AI Task")
async def _async_generate_data(
self,
task: ai_task.GenDataTask,
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
extraction_method = ResultExtractionMethod.NONE
try:
raw_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_PROMPT)
task_prompt = self.runtime_options.get(CONF_AI_TASK_PROMPT, DEFAULT_AI_TASK_PROMPT)
retries = max(0, self.runtime_options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES))
extraction_method = self.runtime_options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)
max_attempts = retries + 1
message_history = chat_log.content[:]
entity_options = {**self.runtime_options}
if task.structure and extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
entity_options[CONF_RESPONSE_JSON_SCHEMA] = convert_to_openapi(task.structure)
if not isinstance(message_history[0], conversation.SystemContent):
system_prompt = conversation.SystemContent(content=self.client._generate_system_prompt(raw_prompt, None, self.runtime_options))
message_history.insert(0, system_prompt)
_LOGGER.debug(f"Generating response for {task.name=}...")
generation_result = await self.client._async_generate(message_history, self.entity_id, chat_log, self.runtime_options)
message_history = list(chat_log.content) if chat_log.content else []
assistant_message = await anext(generation_result)
if not isinstance(assistant_message, conversation.AssistantContent):
raise HomeAssistantError("Last content in chat log is not an AssistantContent!")
text = assistant_message.content
if not task.structure:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=text,
)
if extraction_method == ResultExtractionMethod.NONE:
raise HomeAssistantError("Task structure provided but no extraction method was specified!")
elif extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
try:
data = json_loads(text)
except JSONDecodeError as err:
_LOGGER.error(
"Failed to parse JSON response: %s. Response: %s",
err,
text,
)
raise HomeAssistantError("Error with Local LLM structured response") from err
elif extraction_method == ResultExtractionMethod.TOOL:
try:
data = assistant_message.tool_calls[0].tool_args
except (IndexError, AttributeError) as err:
_LOGGER.error(
"Failed to extract tool arguments from response: %s. Response: %s",
err,
text,
)
raise HomeAssistantError("Error with Local LLM tool response") from err
system_message = conversation.SystemContent(content=task_prompt)
if message_history and isinstance(message_history[0], conversation.SystemContent):
message_history[0] = system_message
else:
raise ValueError() # should not happen
message_history.insert(0, system_message)
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=data,
)
if not any(isinstance(msg, conversation.UserContent) for msg in message_history):
message_history.append(
conversation.UserContent(
content=task.instructions, attachments=task.attachments
)
)
if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT and not task.structure:
raise HomeAssistantError(
"Structured extraction selected but no task structure was provided"
)
if extraction_method == ResultExtractionMethod.TOOL:
if not task.structure:
raise HomeAssistantError(
"Tool extraction selected but no task structure was provided"
)
parameters_schema = vol.Schema({}, extra=vol.ALLOW_EXTRA)
if isinstance(task.structure, dict):
parameters_schema = vol.Schema(task.structure)
chat_log.llm_api = await SubmitResponseAPI(self.hass, [SubmitResponseTool(parameters_schema)]).async_get_api_instance(
llm.LLMContext(DOMAIN, context=None, language=None, assistant=None, device_id=None)
)
last_error: Exception | None = None
for attempt in range(max_attempts):
try:
_LOGGER.debug(
"Generating response for %s (attempt %s/%s)...",
task.name,
attempt + 1,
max_attempts,
)
text, tool_calls = await self._generate_once(message_history, chat_log, entity_options)
return self._extract_data(text, tool_calls, extraction_method, chat_log)
except HomeAssistantError as err:
last_error = err
if attempt < max_attempts - 1:
continue
raise
except Exception as err:
last_error = err
_LOGGER.exception(
"Unhandled exception while running AI Task '%s'",
task.name,
)
raise HomeAssistantError(
f"Unhandled error while running AI Task '{task.name}'"
) from err
if last_error:
raise last_error
raise HomeAssistantError("AI Task generation failed without an error")
except Exception as err:
_LOGGER.exception("Unhandled exception while running AI Task '%s'", task.name)
raise HomeAssistantError(f"Unhandled error while running AI Task '{task.name}'") from err
raise HomeAssistantError(
f"Unhandled error while running AI Task '{task.name}'"
) from err
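For illustration, a rough usage sketch of driving this entity through Home Assistant's standard ai_task.generate_data action once an ai_task_data subentry is configured. The entity id, task fields, and structure format below are assumptions, not part of this change.

from homeassistant.core import HomeAssistant

async def request_summary(hass: HomeAssistant) -> dict:
    # Hypothetical call; entity id and structure field format are placeholders.
    # With extraction method "structure", the task structure is converted via
    # convert_to_openapi() and sent as a response JSON schema; with "tool", it
    # becomes the parameters of the submit_response tool.
    return await hass.services.async_call(
        "ai_task",
        "generate_data",
        {
            "entity_id": "ai_task.local_llm_task",
            "task_name": "energy_summary",
            "instructions": "Summarize today's energy usage.",
            "structure": {
                "summary": {"required": True, "selector": {"text": {}}},
                "kwh": {"required": False, "selector": {"number": {"mode": "box"}}},
            },
        },
        blocking=True,
        return_response=True,
    )  # response is expected to carry the extracted payload, e.g. {"data": {...}}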

View File

@@ -28,6 +28,7 @@ from custom_components.llama_conversation.const import (
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
CONF_GENERIC_OPENAI_PATH,
CONF_ENABLE_LEGACY_TOOL_CALLING,
CONF_RESPONSE_JSON_SCHEMA,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
@@ -120,7 +121,8 @@ class GenericOpenAIAPIClient(LocalLLMClient):
conversation: List[conversation.Content],
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
entity_options: dict[str, Any],
) -> AsyncGenerator[TextGenerationResult, None]:
model_name = entity_options[CONF_CHAT_MODEL]
temperature = entity_options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = entity_options.get(CONF_TOP_P, DEFAULT_TOP_P)
@@ -138,6 +140,17 @@ class GenericOpenAIAPIClient(LocalLLMClient):
"messages": messages
}
response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
if response_json_schema:
request_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": "ha_task",
"schema": response_json_schema,
"strict": True,
},
}
tools = None
# "legacy" tool calling passes the tools directly as part of the system prompt instead of as "tools"
# most local backends absolutely butcher any sort of prompt formatting when using tool calling
@@ -258,7 +271,6 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
try:
if msg.role == "user":
input_text = msg.content
break
except Exception:
continue
@@ -358,7 +370,8 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
conversation: List[conversation.Content],
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: dict[str, Any]) -> TextGenerationResult:
entity_options: dict[str, Any],
) -> TextGenerationResult:
"""Generate a response using the OpenAI-compatible Responses API (non-streaming endpoint wrapped as a single-chunk stream)."""
model_name = entity_options.get(CONF_CHAT_MODEL)
@@ -368,6 +381,16 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
request_params: Dict[str, Any] = {
"model": model_name,
}
response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
if response_json_schema:
request_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": "ha_task",
"schema": response_json_schema,
"strict": True,
},
}
request_params.update(additional_params)
headers: Dict[str, Any] = {}
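For illustration, the chat-completions request body assembled above looks roughly like this once a response JSON schema is configured (the model name and schema contents are placeholders):

# Rough sketch of request_params with CONF_RESPONSE_JSON_SCHEMA set; values are placeholders.
example_request_params = {
    "model": "some-local-model",
    "messages": [{"role": "user", "content": "Summarize today's energy usage."}],
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "ha_task",
            "schema": {
                "type": "object",
                "properties": {"summary": {"type": "string"}},
                "required": ["summary"],
            },
            "strict": True,
        },
    },
}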

View File

@@ -57,6 +57,7 @@ from custom_components.llama_conversation.const import (
DEFAULT_LLAMACPP_THREAD_COUNT,
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
DOMAIN,
CONF_RESPONSE_JSON_SCHEMA,
)
from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult
@@ -64,10 +65,15 @@ from custom_components.llama_conversation.entity import LocalLLMClient, TextGene
from typing import TYPE_CHECKING
from types import ModuleType
if TYPE_CHECKING:
from llama_cpp import Llama as LlamaType, LlamaGrammar as LlamaGrammarType
from llama_cpp import (
Llama as LlamaType,
LlamaGrammar as LlamaGrammarType,
ChatCompletionRequestResponseFormat
)
else:
LlamaType = Any
LlamaGrammarType = Any
ChatCompletionRequestResponseFormat = Any
_LOGGER = logging.getLogger(__name__)
@@ -441,6 +447,14 @@ class LlamaCppClient(LocalLLMClient):
_LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...")
response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
response_format: Optional[ChatCompletionRequestResponseFormat] = None
if response_json_schema:
response_format = {
"type": "json_object",
"schema": response_json_schema,
}
chat_completion = self.models[model_name].create_chat_completion(
messages,
tools=tools,
@@ -452,6 +466,7 @@ class LlamaCppClient(LocalLLMClient):
max_tokens=max_tokens,
grammar=grammar,
stream=True,
response_format=response_format,
)
def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:

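For illustration, the response_format value handed to llama-cpp-python's create_chat_completion when a response JSON schema is configured; the schema contents are a placeholder. llama-cpp-python's "json_object" format accepts an optional "schema" key and constrains generation to it.

# Rough sketch; the schema is a placeholder example.
example_response_format = {
    "type": "json_object",
    "schema": {
        "type": "object",
        "properties": {"summary": {"type": "string"}},
        "required": ["summary"],
    },
}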
View File

@@ -33,6 +33,7 @@ from custom_components.llama_conversation.const import (
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
CONF_ENABLE_LEGACY_TOOL_CALLING,
CONF_RESPONSE_JSON_SCHEMA,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
@@ -216,13 +217,14 @@ class OllamaAPIClient(LocalLLMClient):
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
client = self._build_client(timeout=timeout)
try:
format_option = entity_options.get(CONF_RESPONSE_JSON_SCHEMA, "json" if json_mode else None)
stream = await client.chat(
model=model_name,
messages=messages,
tools=tools,
stream=True,
think=think_mode,
format="json" if json_mode else None,
format=format_option,
options=options,
keep_alive=keep_alive_payload,
)
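For illustration, a standalone sketch of the same call pattern against the Ollama client; the host, model, and schema are placeholders, and recent Ollama releases accept a JSON schema dict for the format argument.

# Standalone sketch, not part of the change; assumes a local Ollama server.
import asyncio
from ollama import AsyncClient

async def demo() -> None:
    schema = {
        "type": "object",
        "properties": {"summary": {"type": "string"}},
        "required": ["summary"],
    }
    client = AsyncClient(host="http://localhost:11434")
    response = await client.chat(
        model="llama3.1",
        messages=[{"role": "user", "content": "Summarize today's energy usage."}],
        format=schema,  # falls back to "json" (or None) when no schema is configured
        stream=False,
    )
    print(response.message.content)  # JSON text conforming to the schema

if __name__ == "__main__":
    asyncio.run(demo())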

View File

@@ -46,6 +46,12 @@ from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_AI_TASK_PROMPT,
DEFAULT_AI_TASK_PROMPT,
CONF_AI_TASK_RETRIES,
DEFAULT_AI_TASK_RETRIES,
CONF_AI_TASK_EXTRACTION_METHOD,
DEFAULT_AI_TASK_EXTRACTION_METHOD,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
@@ -150,7 +156,7 @@ from .const import (
DEFAULT_OPTIONS,
option_overrides,
RECOMMENDED_CHAT_MODELS,
EMBEDDED_LLAMA_CPP_PYTHON_VERSION
EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
)
from . import HomeLLMAPI, LocalLLMConfigEntry, LocalLLMClient, BACKEND_TO_CLS
@@ -400,7 +406,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
"""Return subentries supported by this integration."""
return {
"conversation": LocalLLMSubentryFlowHandler,
# "ai_task_data": LocalLLMSubentryFlowHandler,
"ai_task_data": LocalLLMSubentryFlowHandler,
}
@@ -584,69 +590,98 @@ def local_llama_config_option_schema(
subentry_type: str,
) -> dict:
default_prompt = build_prompt_template(language, DEFAULT_PROMPT)
is_ai_task = subentry_type == "ai_task_data"
default_prompt = DEFAULT_AI_TASK_PROMPT if is_ai_task else build_prompt_template(language, DEFAULT_PROMPT)
prompt_key = CONF_AI_TASK_PROMPT if is_ai_task else CONF_PROMPT
prompt_selector = TextSelector(TextSelectorConfig(type=TextSelectorType.TEXT, multiline=True)) if is_ai_task else TemplateSelector()
result: dict = {
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, default_prompt)},
default=options.get(CONF_PROMPT, default_prompt),
): TemplateSelector(),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)},
default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_IN_CONTEXT_EXAMPLES_FILE,
description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
): str,
vol.Required(
CONF_NUM_IN_CONTEXT_EXAMPLES,
description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
vol.Required(
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
): TextSelector(TextSelectorConfig(multiple=True)),
vol.Required(
CONF_THINKING_PREFIX,
description={"suggested_value": options.get(CONF_THINKING_PREFIX)},
default=DEFAULT_THINKING_PREFIX,
): str,
vol.Required(
CONF_THINKING_SUFFIX,
description={"suggested_value": options.get(CONF_THINKING_SUFFIX)},
default=DEFAULT_THINKING_SUFFIX,
): str,
vol.Required(
CONF_TOOL_CALL_PREFIX,
description={"suggested_value": options.get(CONF_TOOL_CALL_PREFIX)},
default=DEFAULT_TOOL_CALL_PREFIX,
): str,
vol.Required(
CONF_TOOL_CALL_SUFFIX,
description={"suggested_value": options.get(CONF_TOOL_CALL_SUFFIX)},
default=DEFAULT_TOOL_CALL_SUFFIX,
): str,
vol.Required(
CONF_ENABLE_LEGACY_TOOL_CALLING,
description={"suggested_value": options.get(CONF_ENABLE_LEGACY_TOOL_CALLING)},
default=DEFAULT_ENABLE_LEGACY_TOOL_CALLING
): bool,
}
if is_ai_task:
result: dict = {
vol.Optional(
prompt_key,
description={"suggested_value": options.get(prompt_key, default_prompt)},
default=options.get(prompt_key, default_prompt),
): prompt_selector,
vol.Required(
CONF_AI_TASK_EXTRACTION_METHOD,
description={"suggested_value": options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)},
default=options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD),
): SelectSelector(SelectSelectorConfig(
options=[
SelectOptionDict(value="none", label="None"),
SelectOptionDict(value="structure", label="Structured output"),
SelectOptionDict(value="tool", label="Tool call"),
],
mode=SelectSelectorMode.DROPDOWN,
)),
vol.Required(
CONF_AI_TASK_RETRIES,
description={"suggested_value": options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES)},
default=options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES),
): NumberSelector(NumberSelectorConfig(min=0, max=5, step=1, mode=NumberSelectorMode.BOX)),
}
else:
result: dict = {
vol.Optional(
prompt_key,
description={"suggested_value": options.get(prompt_key, default_prompt)},
default=options.get(prompt_key, default_prompt),
): prompt_selector,
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)},
default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_IN_CONTEXT_EXAMPLES_FILE,
description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
): str,
vol.Required(
CONF_NUM_IN_CONTEXT_EXAMPLES,
description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
vol.Required(
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
): TextSelector(TextSelectorConfig(multiple=True)),
vol.Required(
CONF_THINKING_PREFIX,
description={"suggested_value": options.get(CONF_THINKING_PREFIX)},
default=DEFAULT_THINKING_PREFIX,
): str,
vol.Required(
CONF_THINKING_SUFFIX,
description={"suggested_value": options.get(CONF_THINKING_SUFFIX)},
default=DEFAULT_THINKING_SUFFIX,
): str,
vol.Required(
CONF_TOOL_CALL_PREFIX,
description={"suggested_value": options.get(CONF_TOOL_CALL_PREFIX)},
default=DEFAULT_TOOL_CALL_PREFIX,
): str,
vol.Required(
CONF_TOOL_CALL_SUFFIX,
description={"suggested_value": options.get(CONF_TOOL_CALL_SUFFIX)},
default=DEFAULT_TOOL_CALL_SUFFIX,
): str,
vol.Required(
CONF_ENABLE_LEGACY_TOOL_CALLING,
description={"suggested_value": options.get(CONF_ENABLE_LEGACY_TOOL_CALLING)},
default=DEFAULT_ENABLE_LEGACY_TOOL_CALLING
): bool,
}
if backend_type == BACKEND_TYPE_LLAMA_CPP:
result.update({
vol.Required(
vol.Required(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS,
@@ -920,13 +955,17 @@ def local_llama_config_option_schema(
): int,
})
elif subentry_type == "ai_task_data":
pass # no additional options for ai_task_data for now
# no extra conversation/tool options for ai_task_data beyond schema defaults
pass
# sort the options
global_order = [
# general
CONF_LLM_HASS_API,
CONF_PROMPT,
CONF_AI_TASK_PROMPT,
CONF_AI_TASK_EXTRACTION_METHOD,
CONF_AI_TASK_RETRIES,
CONF_CONTEXT_LENGTH,
CONF_MAX_TOKENS,
# sampling parameters
@@ -1116,8 +1155,16 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
description_placeholders = {}
entry = self._get_entry()
backend_type = entry.data[CONF_BACKEND_TYPE]
is_ai_task = self._subentry_type == "ai_task_data"
if CONF_PROMPT not in self.model_config:
if is_ai_task:
if CONF_AI_TASK_PROMPT not in self.model_config:
self.model_config[CONF_AI_TASK_PROMPT] = DEFAULT_AI_TASK_PROMPT
if CONF_AI_TASK_RETRIES not in self.model_config:
self.model_config[CONF_AI_TASK_RETRIES] = DEFAULT_AI_TASK_RETRIES
if CONF_AI_TASK_EXTRACTION_METHOD not in self.model_config:
self.model_config[CONF_AI_TASK_EXTRACTION_METHOD] = DEFAULT_AI_TASK_EXTRACTION_METHOD
elif CONF_PROMPT not in self.model_config:
# determine selected language from model config or parent options
selected_language = self.model_config.get(
CONF_SELECTED_LANGUAGE, entry.options.get(CONF_SELECTED_LANGUAGE, "en")
@@ -1150,20 +1197,21 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
)
if user_input:
if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
errors["base"] = "sys_refresh_caching_enabled"
if not is_ai_task:
if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
errors["base"] = "sys_refresh_caching_enabled"
if user_input.get(CONF_USE_GBNF_GRAMMAR):
filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_gbnf_file"
description_placeholders["filename"] = filename
if user_input.get(CONF_USE_GBNF_GRAMMAR):
filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_gbnf_file"
description_placeholders["filename"] = filename
if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_icl_file"
description_placeholders["filename"] = filename
if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_icl_file"
description_placeholders["filename"] = filename
# --- Normalize numeric fields to ints to avoid slice/type errors later ---
for key in (
@@ -1172,6 +1220,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
CONF_CONTEXT_LENGTH,
CONF_MAX_TOKENS,
CONF_REQUEST_TIMEOUT,
CONF_AI_TASK_RETRIES,
):
if key in user_input:
user_input[key] = _coerce_int(user_input[key], user_input.get(key) or 0)
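For illustration, the ai_task_data subentry options this schema branch yields when everything is left at the defaults defined in const.py:

# Rough sketch of the default ai_task_data options; keys and defaults mirror const.py.
example_ai_task_options = {
    "ai_task_prompt": (
        "You are a task-specific assistant. Follow the task instructions "
        "and return the requested data."
    ),
    "ai_task_extraction_method": "structure",  # one of "none", "structure", "tool"
    "ai_task_retries": 0,
}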

View File

@@ -8,6 +8,12 @@ SERVICE_TOOL_NAME = "HassCallService"
SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"]
SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"]
CONF_PROMPT = "prompt"
CONF_AI_TASK_PROMPT = "ai_task_prompt"
DEFAULT_AI_TASK_PROMPT = "You are a task-specific assistant. Follow the task instructions and return the requested data."
CONF_AI_TASK_RETRIES = "ai_task_retries"
DEFAULT_AI_TASK_RETRIES = 0
CONF_AI_TASK_EXTRACTION_METHOD = "ai_task_extraction_method"
DEFAULT_AI_TASK_EXTRACTION_METHOD = "structure"
PERSONA_PROMPTS = {
"en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.",
"de": "Du bist \u201eAl\u201c, ein hilfreicher KI-Assistent, der die Ger\u00e4te in einem Haus steuert. F\u00fchren Sie die folgende Aufgabe gem\u00e4\u00df den Anweisungen durch oder beantworten Sie die folgende Frage nur mit den bereitgestellten Informationen.",
@@ -188,6 +194,7 @@ CONF_GENERIC_OPENAI_VALIDATE_MODEL = "openai_validate_model"
DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL = True
CONF_CONTEXT_LENGTH = "context_length"
DEFAULT_CONTEXT_LENGTH = 8192
CONF_RESPONSE_JSON_SCHEMA = "response_json_schema"
CONF_LLAMACPP_BATCH_SIZE = "batch_size"
DEFAULT_LLAMACPP_BATCH_SIZE = 512
CONF_LLAMACPP_THREAD_COUNT = "n_threads"

View File

@@ -24,7 +24,7 @@ from homeassistant.requirements import pip_kwargs
from homeassistant.util import color
from homeassistant.util.package import install_package, is_installed
from voluptuous_openapi import convert
from voluptuous_openapi import convert as convert_to_openapi
from .const import (
DOMAIN,
@@ -285,7 +285,7 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis
"function": {
"name": tool["name"],
"description": f"Call the Home Assistant service '{tool['name']}'",
"parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
"parameters": convert_to_openapi(tool["arguments"], custom_serializer=llm_api.custom_serializer)
}
} for tool in get_home_llm_tools(llm_api, domains) ])
else:
@@ -294,7 +294,7 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
"parameters": convert_to_openapi(tool.parameters, custom_serializer=llm_api.custom_serializer)
}
})
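For illustration, a rough sketch of the shape convert_to_openapi produces for a small voluptuous schema, which is what ultimately feeds the response_format / format fields above (the exact output shape is approximate):

# Standalone sketch, not part of the change.
import voluptuous as vol
from voluptuous_openapi import convert as convert_to_openapi

schema = vol.Schema({vol.Required("summary"): str, vol.Optional("kwh"): float})
print(convert_to_openapi(schema))
# Roughly: {"type": "object",
#           "properties": {"summary": {"type": "string"}, "kwh": {"type": "number"}},
#           "required": ["summary"]}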

View File

@@ -0,0 +1,153 @@
"""Tests for AI Task extraction behavior."""
from typing import Any, cast
import pytest
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from custom_components.llama_conversation.ai_task import (
LocalLLMTaskEntity,
ResultExtractionMethod,
)
from custom_components.llama_conversation.const import (
CONF_AI_TASK_EXTRACTION_METHOD,
)
from custom_components.llama_conversation.entity import TextGenerationResult
class DummyGenTask:
def __init__(self, *, name="task", instructions="do", attachments=None, structure=None):
self.name = name
self.instructions = instructions
self.attachments = attachments or []
self.structure = structure
class DummyChatLog:
def __init__(self, content=None):
self.content = content or []
self.conversation_id = "conv-id"
self.llm_api = None
class DummyClient:
def __init__(self, result: TextGenerationResult):
self._result = result
def _supports_vision(self, _options): # pragma: no cover - not needed for tests
return False
async def _generate(self, _messages, _llm_api, _entity_id, _options):
return self._result
class DummyTaskEntity(LocalLLMTaskEntity):
def __init__(self, hass, client, options):
# Bypass parent init to avoid ConfigEntry/Subentry plumbing.
self.hass = hass
self.client = client
self._runtime_options = options
self.entity_id = "ai_task.test"
self.entry_id = "entry"
self.subentry_id = "subentry"
self._attr_supported_features = 0
@property
def runtime_options(self):
return self._runtime_options
@pytest.mark.asyncio
async def test_no_extraction_returns_raw_text(hass):
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="raw text")),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.NONE},
)
chat_log = DummyChatLog()
task = DummyGenTask()
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
assert result.data == "raw text"
assert chat_log.llm_api is None
@pytest.mark.asyncio
async def test_structured_output_success(hass):
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response='{"foo": 1}')),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure={"foo": int})
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
assert result.data == {"foo": 1}
@pytest.mark.asyncio
async def test_structured_output_invalid_json_raises(hass):
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="not-json")),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure={"foo": int})
with pytest.raises(HomeAssistantError):
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
@pytest.mark.asyncio
async def test_tool_extraction_success(hass):
tool_call = llm.ToolInput("submit_response", {"value": 42})
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="", tool_calls=[tool_call])),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure={"value": int})
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
assert result.data == {"value": 42}
assert chat_log.llm_api is not None
@pytest.mark.asyncio
async def test_tool_extraction_missing_tool_args_raises(hass):
class DummyToolCall:
def __init__(self, tool_args=None):
self.tool_args = tool_args
tool_call = DummyToolCall(tool_args=None)
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="", tool_calls=cast(Any, [tool_call]))),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure={"value": int})
with pytest.raises(HomeAssistantError):
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
@pytest.mark.asyncio
async def test_tool_extraction_requires_structure(hass):
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="")),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure=None)
with pytest.raises(HomeAssistantError):
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))