mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 05:14:02 -05:00
support structured output for AI tasks
.gitignore (vendored) · 3
@@ -11,4 +11,5 @@ main.log
*.xlsx
notes.txt
runpod_bootstrap.sh
*.code-workspace
.coverage
@@ -5,18 +5,30 @@ from __future__ import annotations
from json import JSONDecodeError
import logging
from enum import StrEnum
from typing import Any, cast

import voluptuous as vol
from voluptuous_openapi import convert as convert_to_openapi

from homeassistant.helpers import llm
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, Context
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads


from .entity import LocalLLMEntity, LocalLLMClient
from .const import (
    CONF_PROMPT,
    DEFAULT_PROMPT,
    CONF_RESPONSE_JSON_SCHEMA,
    CONF_AI_TASK_PROMPT,
    DEFAULT_AI_TASK_PROMPT,
    CONF_AI_TASK_RETRIES,
    DEFAULT_AI_TASK_RETRIES,
    CONF_AI_TASK_EXTRACTION_METHOD,
    DEFAULT_AI_TASK_EXTRACTION_METHOD,
    DOMAIN,
)

_LOGGER = logging.getLogger(__name__)
@@ -37,90 +49,221 @@ async def async_setup_entry(
            config_subentry_id=subentry.subentry_id,
        )


class ResultExtractionMethod(StrEnum):
    NONE = "none"
    STRUCTURED_OUTPUT = "structure"
    TOOL = "tool"

class SubmitResponseTool(llm.Tool):
    name = "submit_response"
    description = "Submit the structured response payload for the AI task"

    def __init__(self, parameters_schema: vol.Schema):
        self.parameters = parameters_schema

    async def async_call(
        self,
        hass: HomeAssistant,
        tool_input: llm.ToolInput,
        llm_context: llm.LLMContext,
    ) -> dict:
        return tool_input.tool_args or {}


class SubmitResponseAPI(llm.API):
    def __init__(self, hass: HomeAssistant, tools: list[llm.Tool]) -> None:
        self._tools = tools
        super().__init__(
            hass=hass,
            id=f"{DOMAIN}-ai-task-tool",
            name="AI Task Tool API",
        )

    async def async_get_api_instance(
        self, llm_context: llm.LLMContext
    ) -> llm.APIInstance:
        return llm.APIInstance(
            api=self,
            api_prompt="Call submit_response to return the structured AI task result.",
            llm_context=llm_context,
            tools=self._tools,
        )


class LocalLLMTaskEntity(
    ai_task.AITaskEntity,
    LocalLLMEntity,
):
    """Ollama AI Task entity."""
    """AI Task entity."""

    def __init__(self, *args, **kwargs) -> None:
        """Initialize Ollama AI Task entity."""
        """Initialize AI Task entity."""
        super().__init__(*args, **kwargs)

        if self.client._supports_vision(self.runtime_options):
            self._attr_supported_features = (
                ai_task.AITaskEntityFeature.GENERATE_DATA |
                ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
                ai_task.AITaskEntityFeature.GENERATE_DATA
                | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
            )
        else:
            self._attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA

    async def _generate_once(
        self,
        message_history: list[conversation.Content],
        chat_log: conversation.ChatLog,
        entity_options: dict[str, Any],
    ) -> tuple[str, list | None]:
        """Generate a single response from the LLM."""
        collected: list[str] = []
        collected_tools = None

        # call the LLM client directly (not _async_generate) since that will attempt to execute tool calls
        if hasattr(self.client, "_generate_stream"):
            async for chunk in self.client._generate_stream(
                message_history,
                chat_log.llm_api,
                self.entity_id,
                entity_options,
            ):
                if chunk.response:
                    collected.append(chunk.response)
                if chunk.tool_calls:
                    collected_tools = chunk.tool_calls
        else:
            blocking_result = await self.client._generate(
                message_history,
                chat_log.llm_api,
                self.entity_id,
                entity_options,
            )
            if blocking_result.response:
                collected.append(blocking_result.response)
            if blocking_result.tool_calls:
                collected_tools = blocking_result.tool_calls

        text = "".join(collected).strip()
        return text, collected_tools

    def _extract_data(
        self,
        raw_text: str,
        tool_calls: list | None,
        extraction_method: ResultExtractionMethod,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenDataTaskResult:
        """Extract the final data from the LLM response based on the extraction method."""
        if extraction_method == ResultExtractionMethod.NONE:
            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=raw_text,
            )

        if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
            try:
                data = json_loads(raw_text)
            except JSONDecodeError as err:
                raise HomeAssistantError(
                    "Error with Local LLM structured response"
                ) from err
            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=data,
            )

        if extraction_method == ResultExtractionMethod.TOOL:
            first_tool = (tool_calls or [None])[0]
            if not first_tool or not getattr(first_tool, "tool_args", None):
                raise HomeAssistantError("Error with Local LLM tool response")
            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=first_tool.tool_args,
            )

        raise HomeAssistantError("Invalid extraction method for AI Task")

    async def _async_generate_data(
        self,
        task: ai_task.GenDataTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenDataTaskResult:
        """Handle a generate data task."""

        extraction_method = ResultExtractionMethod.NONE

        try:
            raw_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_PROMPT)
            task_prompt = self.runtime_options.get(CONF_AI_TASK_PROMPT, DEFAULT_AI_TASK_PROMPT)
            retries = max(0, self.runtime_options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES))
            extraction_method = self.runtime_options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)
            max_attempts = retries + 1

            message_history = chat_log.content[:]
            entity_options = {**self.runtime_options}
            if task.structure and extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
                entity_options[CONF_RESPONSE_JSON_SCHEMA] = convert_to_openapi(task.structure)

            if not isinstance(message_history[0], conversation.SystemContent):
                system_prompt = conversation.SystemContent(content=self.client._generate_system_prompt(raw_prompt, None, self.runtime_options))
                message_history.insert(0, system_prompt)

            _LOGGER.debug(f"Generating response for {task.name=}...")
            generation_result = await self.client._async_generate(message_history, self.entity_id, chat_log, self.runtime_options)
            message_history = list(chat_log.content) if chat_log.content else []

            assistant_message = await anext(generation_result)
            if not isinstance(assistant_message, conversation.AssistantContent):
                raise HomeAssistantError("Last content in chat log is not an AssistantContent!")
            text = assistant_message.content

            if not task.structure:
                return ai_task.GenDataTaskResult(
                    conversation_id=chat_log.conversation_id,
                    data=text,
                )

            if extraction_method == ResultExtractionMethod.NONE:
                raise HomeAssistantError("Task structure provided but no extraction method was specified!")
            elif extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
                try:
                    data = json_loads(text)
                except JSONDecodeError as err:
                    _LOGGER.error(
                        "Failed to parse JSON response: %s. Response: %s",
                        err,
                        text,
                    )
                    raise HomeAssistantError("Error with Local LLM structured response") from err
            elif extraction_method == ResultExtractionMethod.TOOL:
                try:
                    data = assistant_message.tool_calls[0].tool_args
                except (IndexError, AttributeError) as err:
                    _LOGGER.error(
                        "Failed to extract tool arguments from response: %s. Response: %s",
                        err,
                        text,
                    )
                    raise HomeAssistantError("Error with Local LLM tool response") from err
            system_message = conversation.SystemContent(content=task_prompt)
            if message_history and isinstance(message_history[0], conversation.SystemContent):
                message_history[0] = system_message
            else:
                raise ValueError() # should not happen
                message_history.insert(0, system_message)

            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=data,
            )
            if not any(isinstance(msg, conversation.UserContent) for msg in message_history):
                message_history.append(
                    conversation.UserContent(
                        content=task.instructions, attachments=task.attachments
                    )
                )

            if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT and not task.structure:
                raise HomeAssistantError(
                    "Structured extraction selected but no task structure was provided"
                )
            if extraction_method == ResultExtractionMethod.TOOL:
                if not task.structure:
                    raise HomeAssistantError(
                        "Tool extraction selected but no task structure was provided"
                    )

                parameters_schema = vol.Schema({}, extra=vol.ALLOW_EXTRA)
                if isinstance(task.structure, dict):
                    parameters_schema = vol.Schema(task.structure)

                chat_log.llm_api = await SubmitResponseAPI(self.hass, [SubmitResponseTool(parameters_schema)]).async_get_api_instance(
                    llm.LLMContext(DOMAIN, context=None, language=None, assistant=None, device_id=None)
                )

            last_error: Exception | None = None
            for attempt in range(max_attempts):
                try:
                    _LOGGER.debug(
                        "Generating response for %s (attempt %s/%s)...",
                        task.name,
                        attempt + 1,
                        max_attempts,
                    )
                    text, tool_calls = await self._generate_once(message_history, chat_log, entity_options)
                    return self._extract_data(text, tool_calls, extraction_method, chat_log)
                except HomeAssistantError as err:
                    last_error = err
                    if attempt < max_attempts - 1:
                        continue
                    raise
                except Exception as err:
                    last_error = err
                    _LOGGER.exception(
                        "Unhandled exception while running AI Task '%s'",
                        task.name,
                    )
                    raise HomeAssistantError(
                        f"Unhandled error while running AI Task '{task.name}'"
                    ) from err

            if last_error:
                raise last_error

            raise HomeAssistantError("AI Task generation failed without an error")
        except Exception as err:
            _LOGGER.exception("Unhandled exception while running AI Task '%s'", task.name)
            raise HomeAssistantError(f"Unhandled error while running AI Task '{task.name}'") from err
            raise HomeAssistantError(
                f"Unhandled error while running AI Task '{task.name}'"
            ) from err

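As an aside (not part of the diff): `convert_to_openapi` above is `voluptuous_openapi.convert`, so the task structure, a voluptuous schema, becomes a plain JSON-schema dict that the backends receive under CONF_RESPONSE_JSON_SCHEMA. A minimal sketch of that conversion, with an invented structure:

# Illustrative only; the field names below are made up for the example.
import voluptuous as vol
from voluptuous_openapi import convert as convert_to_openapi

structure = vol.Schema({vol.Required("brightness"): int, vol.Optional("color"): str})
response_json_schema = convert_to_openapi(structure)
# response_json_schema is roughly:
# {
#     "type": "object",
#     "properties": {"brightness": {"type": "integer"}, "color": {"type": "string"}},
#     "required": ["brightness"],
# }
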
@@ -28,6 +28,7 @@ from custom_components.llama_conversation.const import (
    CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
    CONF_GENERIC_OPENAI_PATH,
    CONF_ENABLE_LEGACY_TOOL_CALLING,
    CONF_RESPONSE_JSON_SCHEMA,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
@@ -120,7 +121,8 @@ class GenericOpenAIAPIClient(LocalLLMClient):
        conversation: List[conversation.Content],
        llm_api: llm.APIInstance | None,
        agent_id: str,
        entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
        entity_options: dict[str, Any],
    ) -> AsyncGenerator[TextGenerationResult, None]:
        model_name = entity_options[CONF_CHAT_MODEL]
        temperature = entity_options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_p = entity_options.get(CONF_TOP_P, DEFAULT_TOP_P)
@@ -138,6 +140,17 @@ class GenericOpenAIAPIClient(LocalLLMClient):
            "messages": messages
        }

        response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
        if response_json_schema:
            request_params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "ha_task",
                    "schema": response_json_schema,
                    "strict": True,
                },
            }

        tools = None
        # "legacy" tool calling passes the tools directly as part of the system prompt instead of as "tools"
        # most local backends absolutely butcher any sort of prompt formatting when using tool calling
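For orientation, when a schema is present the chat-completions request body built above ends up looking roughly like this; the model name, message, and schema below are example values, not taken from the integration:

# Hypothetical payload sent to an OpenAI-compatible /v1/chat/completions endpoint.
request_params = {
    "model": "qwen2.5-7b-instruct",  # example model name
    "messages": [{"role": "user", "content": "Summarize today's energy use as JSON."}],
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "ha_task",
            "schema": {"type": "object", "properties": {"total_kwh": {"type": "number"}}},
            "strict": True,
        },
    },
}
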
@@ -258,7 +271,6 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
            try:
                if msg.role == "user":
                    input_text = msg.content
                    break
            except Exception:
                continue

@@ -358,7 +370,8 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
        conversation: List[conversation.Content],
        llm_api: llm.APIInstance | None,
        agent_id: str,
        entity_options: dict[str, Any]) -> TextGenerationResult:
        entity_options: dict[str, Any],
    ) -> TextGenerationResult:
        """Generate a response using the OpenAI-compatible Responses API (non-streaming endpoint wrapped as a single-chunk stream)."""

        model_name = entity_options.get(CONF_CHAT_MODEL)
@@ -368,6 +381,16 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
        request_params: Dict[str, Any] = {
            "model": model_name,
        }
        response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
        if response_json_schema:
            request_params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "ha_task",
                    "schema": response_json_schema,
                    "strict": True,
                },
            }
        request_params.update(additional_params)

        headers: Dict[str, Any] = {}

@@ -57,6 +57,7 @@ from custom_components.llama_conversation.const import (
    DEFAULT_LLAMACPP_THREAD_COUNT,
    DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
    DOMAIN,
    CONF_RESPONSE_JSON_SCHEMA,
)
from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult

@@ -64,10 +65,15 @@ from custom_components.llama_conversation.entity import LocalLLMClient, TextGene
from typing import TYPE_CHECKING
from types import ModuleType
if TYPE_CHECKING:
    from llama_cpp import Llama as LlamaType, LlamaGrammar as LlamaGrammarType
    from llama_cpp import (
        Llama as LlamaType,
        LlamaGrammar as LlamaGrammarType,
        ChatCompletionRequestResponseFormat
    )
else:
    LlamaType = Any
    LlamaGrammarType = Any
    ChatCompletionRequestResponseFormat = Any

_LOGGER = logging.getLogger(__name__)

@@ -441,6 +447,14 @@ class LlamaCppClient(LocalLLMClient):

        _LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...")

        response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
        response_format: Optional[ChatCompletionRequestResponseFormat] = None
        if response_json_schema:
            response_format = {
                "type": "json_object",
                "schema": response_json_schema,
            }

        chat_completion = self.models[model_name].create_chat_completion(
            messages,
            tools=tools,
@@ -452,6 +466,7 @@ class LlamaCppClient(LocalLLMClient):
            max_tokens=max_tokens,
            grammar=grammar,
            stream=True,
            response_format=response_format,
        )

        def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:

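A standalone sketch of the same idea against llama-cpp-python: a response_format of type "json_object" with a "schema" key asks the library to constrain generation to that schema. The model path and schema below are placeholders, not values from the integration.

# Sketch only; not part of the diff.
from llama_cpp import Llama

llm = Llama(model_path="/models/home-llm.Q4_K_M.gguf")  # hypothetical model path
completion = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Report the thermostat state as JSON."}],
    response_format={
        "type": "json_object",
        "schema": {"type": "object", "properties": {"temperature": {"type": "number"}}},
    },
)
print(completion["choices"][0]["message"]["content"])  # JSON text matching the schema
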
@@ -33,6 +33,7 @@ from custom_components.llama_conversation.const import (
    CONF_OLLAMA_JSON_MODE,
    CONF_CONTEXT_LENGTH,
    CONF_ENABLE_LEGACY_TOOL_CALLING,
    CONF_RESPONSE_JSON_SCHEMA,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
@@ -216,13 +217,14 @@ class OllamaAPIClient(LocalLLMClient):
        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
            client = self._build_client(timeout=timeout)
            try:
                format_option = entity_options.get(CONF_RESPONSE_JSON_SCHEMA, "json" if json_mode else None)
                stream = await client.chat(
                    model=model_name,
                    messages=messages,
                    tools=tools,
                    stream=True,
                    think=think_mode,
                    format="json" if json_mode else None,
                    format=format_option,
                    options=options,
                    keep_alive=keep_alive_payload,
                )

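The ollama client's `format` argument accepts either the literal "json" or a JSON schema dict (structured outputs), which is why the converted schema can be passed straight through here. A minimal direct-call sketch with placeholder values:

# Sketch only; model name and schema are examples, not values from the integration.
from ollama import AsyncClient

async def fetch_reading() -> str:
    client = AsyncClient()
    response = await client.chat(
        model="llama3.1",  # assumed model name
        messages=[{"role": "user", "content": "Give the sensor reading as JSON."}],
        format={"type": "object", "properties": {"value": {"type": "number"}}},
    )
    return response.message.content  # JSON text constrained to the schema
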
@@ -46,6 +46,12 @@ from .const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
    CONF_PROMPT,
    CONF_AI_TASK_PROMPT,
    DEFAULT_AI_TASK_PROMPT,
    CONF_AI_TASK_RETRIES,
    DEFAULT_AI_TASK_RETRIES,
    CONF_AI_TASK_EXTRACTION_METHOD,
    DEFAULT_AI_TASK_EXTRACTION_METHOD,
    CONF_TEMPERATURE,
    CONF_TOP_K,
    CONF_TOP_P,
@@ -150,7 +156,7 @@ from .const import (
    DEFAULT_OPTIONS,
    option_overrides,
    RECOMMENDED_CHAT_MODELS,
    EMBEDDED_LLAMA_CPP_PYTHON_VERSION
    EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
)

from . import HomeLLMAPI, LocalLLMConfigEntry, LocalLLMClient, BACKEND_TO_CLS
@@ -400,7 +406,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
        """Return subentries supported by this integration."""
        return {
            "conversation": LocalLLMSubentryFlowHandler,
            # "ai_task_data": LocalLLMSubentryFlowHandler,
            "ai_task_data": LocalLLMSubentryFlowHandler,
        }


@@ -584,69 +590,98 @@ def local_llama_config_option_schema(
    subentry_type: str,
) -> dict:

    default_prompt = build_prompt_template(language, DEFAULT_PROMPT)
    is_ai_task = subentry_type == "ai_task_data"
    default_prompt = DEFAULT_AI_TASK_PROMPT if is_ai_task else build_prompt_template(language, DEFAULT_PROMPT)
    prompt_key = CONF_AI_TASK_PROMPT if is_ai_task else CONF_PROMPT
    prompt_selector = TextSelector(TextSelectorConfig(type=TextSelectorType.TEXT, multiline=True)) if is_ai_task else TemplateSelector()

    result: dict = {
        vol.Optional(
            CONF_PROMPT,
            description={"suggested_value": options.get(CONF_PROMPT, default_prompt)},
            default=options.get(CONF_PROMPT, default_prompt),
        ): TemplateSelector(),
        vol.Optional(
            CONF_TEMPERATURE,
            description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)},
            default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
        ): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)),
        vol.Required(
            CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
            description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
            default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
        ): BooleanSelector(BooleanSelectorConfig()),
        vol.Required(
            CONF_IN_CONTEXT_EXAMPLES_FILE,
            description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
            default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
        ): str,
        vol.Required(
            CONF_NUM_IN_CONTEXT_EXAMPLES,
            description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
            default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
        ): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
        vol.Required(
            CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
            description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
            default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
        ): TextSelector(TextSelectorConfig(multiple=True)),
        vol.Required(
            CONF_THINKING_PREFIX,
            description={"suggested_value": options.get(CONF_THINKING_PREFIX)},
            default=DEFAULT_THINKING_PREFIX,
        ): str,
        vol.Required(
            CONF_THINKING_SUFFIX,
            description={"suggested_value": options.get(CONF_THINKING_SUFFIX)},
            default=DEFAULT_THINKING_SUFFIX,
        ): str,
        vol.Required(
            CONF_TOOL_CALL_PREFIX,
            description={"suggested_value": options.get(CONF_TOOL_CALL_PREFIX)},
            default=DEFAULT_TOOL_CALL_PREFIX,
        ): str,
        vol.Required(
            CONF_TOOL_CALL_SUFFIX,
            description={"suggested_value": options.get(CONF_TOOL_CALL_SUFFIX)},
            default=DEFAULT_TOOL_CALL_SUFFIX,
        ): str,
        vol.Required(
            CONF_ENABLE_LEGACY_TOOL_CALLING,
            description={"suggested_value": options.get(CONF_ENABLE_LEGACY_TOOL_CALLING)},
            default=DEFAULT_ENABLE_LEGACY_TOOL_CALLING
        ): bool,
    }
    if is_ai_task:
        result: dict = {
            vol.Optional(
                prompt_key,
                description={"suggested_value": options.get(prompt_key, default_prompt)},
                default=options.get(prompt_key, default_prompt),
            ): prompt_selector,
            vol.Required(
                CONF_AI_TASK_EXTRACTION_METHOD,
                description={"suggested_value": options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)},
                default=options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD),
            ): SelectSelector(SelectSelectorConfig(
                options=[
                    SelectOptionDict(value="none", label="None"),
                    SelectOptionDict(value="structure", label="Structured output"),
                    SelectOptionDict(value="tool", label="Tool call"),
                ],
                mode=SelectSelectorMode.DROPDOWN,
            )),
            vol.Required(
                CONF_AI_TASK_RETRIES,
                description={"suggested_value": options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES)},
                default=options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES),
            ): NumberSelector(NumberSelectorConfig(min=0, max=5, step=1, mode=NumberSelectorMode.BOX)),
        }
    else:
        result: dict = {
            vol.Optional(
                prompt_key,
                description={"suggested_value": options.get(prompt_key, default_prompt)},
                default=options.get(prompt_key, default_prompt),
            ): prompt_selector,
            vol.Optional(
                CONF_TEMPERATURE,
                description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)},
                default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
            ): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)),
            vol.Required(
                CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
                description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
                default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
            ): BooleanSelector(BooleanSelectorConfig()),
            vol.Required(
                CONF_IN_CONTEXT_EXAMPLES_FILE,
                description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
                default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
            ): str,
            vol.Required(
                CONF_NUM_IN_CONTEXT_EXAMPLES,
                description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
                default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
            ): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
            vol.Required(
                CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
                description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
                default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
            ): TextSelector(TextSelectorConfig(multiple=True)),
            vol.Required(
                CONF_THINKING_PREFIX,
                description={"suggested_value": options.get(CONF_THINKING_PREFIX)},
                default=DEFAULT_THINKING_PREFIX,
            ): str,
            vol.Required(
                CONF_THINKING_SUFFIX,
                description={"suggested_value": options.get(CONF_THINKING_SUFFIX)},
                default=DEFAULT_THINKING_SUFFIX,
            ): str,
            vol.Required(
                CONF_TOOL_CALL_PREFIX,
                description={"suggested_value": options.get(CONF_TOOL_CALL_PREFIX)},
                default=DEFAULT_TOOL_CALL_PREFIX,
            ): str,
            vol.Required(
                CONF_TOOL_CALL_SUFFIX,
                description={"suggested_value": options.get(CONF_TOOL_CALL_SUFFIX)},
                default=DEFAULT_TOOL_CALL_SUFFIX,
            ): str,
            vol.Required(
                CONF_ENABLE_LEGACY_TOOL_CALLING,
                description={"suggested_value": options.get(CONF_ENABLE_LEGACY_TOOL_CALLING)},
                default=DEFAULT_ENABLE_LEGACY_TOOL_CALLING
            ): bool,
        }

    if backend_type == BACKEND_TYPE_LLAMA_CPP:
        result.update({
            vol.Required(
            vol.Required(
                CONF_MAX_TOKENS,
                description={"suggested_value": options.get(CONF_MAX_TOKENS)},
                default=DEFAULT_MAX_TOKENS,
@@ -920,13 +955,17 @@ def local_llama_config_option_schema(
            ): int,
        })
    elif subentry_type == "ai_task_data":
        pass # no additional options for ai_task_data for now
        # no extra conversation/tool options for ai_task_data beyond schema defaults
        pass

    # sort the options
    global_order = [
        # general
        CONF_LLM_HASS_API,
        CONF_PROMPT,
        CONF_AI_TASK_PROMPT,
        CONF_AI_TASK_EXTRACTION_METHOD,
        CONF_AI_TASK_RETRIES,
        CONF_CONTEXT_LENGTH,
        CONF_MAX_TOKENS,
        # sampling parameters
@@ -1116,8 +1155,16 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
        description_placeholders = {}
        entry = self._get_entry()
        backend_type = entry.data[CONF_BACKEND_TYPE]
        is_ai_task = self._subentry_type == "ai_task_data"

        if CONF_PROMPT not in self.model_config:
        if is_ai_task:
            if CONF_AI_TASK_PROMPT not in self.model_config:
                self.model_config[CONF_AI_TASK_PROMPT] = DEFAULT_AI_TASK_PROMPT
            if CONF_AI_TASK_RETRIES not in self.model_config:
                self.model_config[CONF_AI_TASK_RETRIES] = DEFAULT_AI_TASK_RETRIES
            if CONF_AI_TASK_EXTRACTION_METHOD not in self.model_config:
                self.model_config[CONF_AI_TASK_EXTRACTION_METHOD] = DEFAULT_AI_TASK_EXTRACTION_METHOD
        elif CONF_PROMPT not in self.model_config:
            # determine selected language from model config or parent options
            selected_language = self.model_config.get(
                CONF_SELECTED_LANGUAGE, entry.options.get(CONF_SELECTED_LANGUAGE, "en")
@@ -1150,20 +1197,21 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
            )

        if user_input:
            if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
                errors["base"] = "sys_refresh_caching_enabled"
            if not is_ai_task:
                if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
                    errors["base"] = "sys_refresh_caching_enabled"

            if user_input.get(CONF_USE_GBNF_GRAMMAR):
                filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
                if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
                    errors["base"] = "missing_gbnf_file"
                    description_placeholders["filename"] = filename
            if user_input.get(CONF_USE_GBNF_GRAMMAR):
                filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
                if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
                    errors["base"] = "missing_gbnf_file"
                    description_placeholders["filename"] = filename

            if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
                filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
                if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
                    errors["base"] = "missing_icl_file"
                    description_placeholders["filename"] = filename
            if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
                filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
                if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
                    errors["base"] = "missing_icl_file"
                    description_placeholders["filename"] = filename

            # --- Normalize numeric fields to ints to avoid slice/type errors later ---
            for key in (
@@ -1172,6 +1220,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
                CONF_CONTEXT_LENGTH,
                CONF_MAX_TOKENS,
                CONF_REQUEST_TIMEOUT,
                CONF_AI_TASK_RETRIES,
            ):
                if key in user_input:
                    user_input[key] = _coerce_int(user_input[key], user_input.get(key) or 0)

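For reference, an ai_task_data subentry saved through this flow ends up with options along these lines; the keys come from const.py below, the values here are illustrative:

# Hypothetical saved options for an ai_task_data subentry.
ai_task_options = {
    "ai_task_prompt": "You are a task-specific assistant. Follow the task instructions and return the requested data.",
    "ai_task_extraction_method": "structure",  # "none", "structure", or "tool"
    "ai_task_retries": 1,  # selector allows 0-5; the default is 0
}
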
@@ -8,6 +8,12 @@ SERVICE_TOOL_NAME = "HassCallService"
SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"]
SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"]
CONF_PROMPT = "prompt"
CONF_AI_TASK_PROMPT = "ai_task_prompt"
DEFAULT_AI_TASK_PROMPT = "You are a task-specific assistant. Follow the task instructions and return the requested data."
CONF_AI_TASK_RETRIES = "ai_task_retries"
DEFAULT_AI_TASK_RETRIES = 0
CONF_AI_TASK_EXTRACTION_METHOD = "ai_task_extraction_method"
DEFAULT_AI_TASK_EXTRACTION_METHOD = "structure"
PERSONA_PROMPTS = {
    "en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.",
    "de": "Du bist \u201eAl\u201c, ein hilfreicher KI-Assistent, der die Ger\u00e4te in einem Haus steuert. F\u00fchren Sie die folgende Aufgabe gem\u00e4\u00df den Anweisungen durch oder beantworten Sie die folgende Frage nur mit den bereitgestellten Informationen.",
@@ -188,6 +194,7 @@ CONF_GENERIC_OPENAI_VALIDATE_MODEL = "openai_validate_model"
DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL = True
CONF_CONTEXT_LENGTH = "context_length"
DEFAULT_CONTEXT_LENGTH = 8192
CONF_RESPONSE_JSON_SCHEMA = "response_json_schema"
CONF_LLAMACPP_BATCH_SIZE = "batch_size"
DEFAULT_LLAMACPP_BATCH_SIZE = 512
CONF_LLAMACPP_THREAD_COUNT = "n_threads"

@@ -24,7 +24,7 @@ from homeassistant.requirements import pip_kwargs
from homeassistant.util import color
from homeassistant.util.package import install_package, is_installed

from voluptuous_openapi import convert
from voluptuous_openapi import convert as convert_to_openapi

from .const import (
    DOMAIN,
@@ -285,7 +285,7 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis
            "function": {
                "name": tool["name"],
                "description": f"Call the Home Assistant service '{tool['name']}'",
                "parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
                "parameters": convert_to_openapi(tool["arguments"], custom_serializer=llm_api.custom_serializer)
            }
        } for tool in get_home_llm_tools(llm_api, domains) ])
    else:
@@ -294,7 +294,7 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis
            "function": {
                "name": tool.name,
                "description": tool.description or "",
                "parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
                "parameters": convert_to_openapi(tool.parameters, custom_serializer=llm_api.custom_serializer)
            }
        })


tests/llama_conversation/test_ai_task.py (new file) · 153
@@ -0,0 +1,153 @@
"""Tests for AI Task extraction behavior."""

from typing import Any, cast

import pytest

from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm

from custom_components.llama_conversation.ai_task import (
    LocalLLMTaskEntity,
    ResultExtractionMethod,
)
from custom_components.llama_conversation.const import (
    CONF_AI_TASK_EXTRACTION_METHOD,
)
from custom_components.llama_conversation.entity import TextGenerationResult


class DummyGenTask:
    def __init__(self, *, name="task", instructions="do", attachments=None, structure=None):
        self.name = name
        self.instructions = instructions
        self.attachments = attachments or []
        self.structure = structure


class DummyChatLog:
    def __init__(self, content=None):
        self.content = content or []
        self.conversation_id = "conv-id"
        self.llm_api = None

class DummyClient:
    def __init__(self, result: TextGenerationResult):
        self._result = result

    def _supports_vision(self, _options): # pragma: no cover - not needed for tests
        return False

    async def _generate(self, _messages, _llm_api, _entity_id, _options):
        return self._result


class DummyTaskEntity(LocalLLMTaskEntity):
    def __init__(self, hass, client, options):
        # Bypass parent init to avoid ConfigEntry/Subentry plumbing.
        self.hass = hass
        self.client = client
        self._runtime_options = options
        self.entity_id = "ai_task.test"
        self.entry_id = "entry"
        self.subentry_id = "subentry"
        self._attr_supported_features = 0

    @property
    def runtime_options(self):
        return self._runtime_options


@pytest.mark.asyncio
async def test_no_extraction_returns_raw_text(hass):
    entity = DummyTaskEntity(
        hass,
        DummyClient(TextGenerationResult(response="raw text")),
        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.NONE},
    )
    chat_log = DummyChatLog()
    task = DummyGenTask()

    result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))

    assert result.data == "raw text"
    assert chat_log.llm_api is None


@pytest.mark.asyncio
async def test_structured_output_success(hass):
    entity = DummyTaskEntity(
        hass,
        DummyClient(TextGenerationResult(response='{"foo": 1}')),
        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
    )
    chat_log = DummyChatLog()
    task = DummyGenTask(structure={"foo": int})

    result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))

    assert result.data == {"foo": 1}


@pytest.mark.asyncio
async def test_structured_output_invalid_json_raises(hass):
    entity = DummyTaskEntity(
        hass,
        DummyClient(TextGenerationResult(response="not-json")),
        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
    )
    chat_log = DummyChatLog()
    task = DummyGenTask(structure={"foo": int})

    with pytest.raises(HomeAssistantError):
        await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))


@pytest.mark.asyncio
async def test_tool_extraction_success(hass):
    tool_call = llm.ToolInput("submit_response", {"value": 42})
    entity = DummyTaskEntity(
        hass,
        DummyClient(TextGenerationResult(response="", tool_calls=[tool_call])),
        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
    )
    chat_log = DummyChatLog()
    task = DummyGenTask(structure={"value": int})

    result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))

    assert result.data == {"value": 42}
    assert chat_log.llm_api is not None


@pytest.mark.asyncio
async def test_tool_extraction_missing_tool_args_raises(hass):
    class DummyToolCall:
        def __init__(self, tool_args=None):
            self.tool_args = tool_args

    tool_call = DummyToolCall(tool_args=None)
    entity = DummyTaskEntity(
        hass,
        DummyClient(TextGenerationResult(response="", tool_calls=cast(Any, [tool_call]))),
        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
    )
    chat_log = DummyChatLog()
    task = DummyGenTask(structure={"value": int})

    with pytest.raises(HomeAssistantError):
        await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))


@pytest.mark.asyncio
async def test_tool_extraction_requires_structure(hass):
    entity = DummyTaskEntity(
        hass,
        DummyClient(TextGenerationResult(response="")),
        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
    )
    chat_log = DummyChatLog()
    task = DummyGenTask(structure=None)

    with pytest.raises(HomeAssistantError):
        await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
Block a user