diff --git a/.gitignore b/.gitignore index 5a0ecbf..8b54951 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ main.log *.xlsx notes.txt runpod_bootstrap.sh -*.code-workspace \ No newline at end of file +*.code-workspace +.coverage \ No newline at end of file diff --git a/custom_components/llama_conversation/ai_task.py b/custom_components/llama_conversation/ai_task.py index b3ea222..337d25e 100644 --- a/custom_components/llama_conversation/ai_task.py +++ b/custom_components/llama_conversation/ai_task.py @@ -5,18 +5,30 @@ from __future__ import annotations from json import JSONDecodeError import logging from enum import StrEnum +from typing import Any, cast +import voluptuous as vol +from voluptuous_openapi import convert as convert_to_openapi + +from homeassistant.helpers import llm from homeassistant.components import ai_task, conversation from homeassistant.config_entries import ConfigEntry -from homeassistant.core import HomeAssistant, Context +from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.util.json import json_loads + from .entity import LocalLLMEntity, LocalLLMClient from .const import ( - CONF_PROMPT, - DEFAULT_PROMPT, + CONF_RESPONSE_JSON_SCHEMA, + CONF_AI_TASK_PROMPT, + DEFAULT_AI_TASK_PROMPT, + CONF_AI_TASK_RETRIES, + DEFAULT_AI_TASK_RETRIES, + CONF_AI_TASK_EXTRACTION_METHOD, + DEFAULT_AI_TASK_EXTRACTION_METHOD, + DOMAIN, ) _LOGGER = logging.getLogger(__name__) @@ -37,90 +49,221 @@ async def async_setup_entry( config_subentry_id=subentry.subentry_id, ) + class ResultExtractionMethod(StrEnum): NONE = "none" STRUCTURED_OUTPUT = "structure" TOOL = "tool" +class SubmitResponseTool(llm.Tool): + name = "submit_response" + description = "Submit the structured response payload for the AI task" + + def __init__(self, parameters_schema: vol.Schema): + self.parameters = parameters_schema + + async def async_call( + self, + hass: HomeAssistant, + tool_input: llm.ToolInput, + llm_context: llm.LLMContext, + ) -> dict: + return tool_input.tool_args or {} + + +class SubmitResponseAPI(llm.API): + def __init__(self, hass: HomeAssistant, tools: list[llm.Tool]) -> None: + self._tools = tools + super().__init__( + hass=hass, + id=f"{DOMAIN}-ai-task-tool", + name="AI Task Tool API", + ) + + async def async_get_api_instance( + self, llm_context: llm.LLMContext + ) -> llm.APIInstance: + return llm.APIInstance( + api=self, + api_prompt="Call submit_response to return the structured AI task result.", + llm_context=llm_context, + tools=self._tools, + ) + + class LocalLLMTaskEntity( ai_task.AITaskEntity, LocalLLMEntity, ): - """Ollama AI Task entity.""" + """AI Task entity.""" def __init__(self, *args, **kwargs) -> None: - """Initialize Ollama AI Task entity.""" + """Initialize AI Task entity.""" super().__init__(*args, **kwargs) if self.client._supports_vision(self.runtime_options): self._attr_supported_features = ( - ai_task.AITaskEntityFeature.GENERATE_DATA | - ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS + ai_task.AITaskEntityFeature.GENERATE_DATA + | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS ) else: self._attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA + async def _generate_once( + self, + message_history: list[conversation.Content], + chat_log: conversation.ChatLog, + entity_options: dict[str, Any], + ) -> tuple[str, list | None]: + """Generate a single response from the LLM.""" + collected: list[str] = [] + 
collected_tools = None + + # call the LLM client directly (not _async_generate) since that will attempt to execute tool calls + if hasattr(self.client, "_generate_stream"): + async for chunk in self.client._generate_stream( + message_history, + chat_log.llm_api, + self.entity_id, + entity_options, + ): + if chunk.response: + collected.append(chunk.response) + if chunk.tool_calls: + collected_tools = chunk.tool_calls + else: + blocking_result = await self.client._generate( + message_history, + chat_log.llm_api, + self.entity_id, + entity_options, + ) + if blocking_result.response: + collected.append(blocking_result.response) + if blocking_result.tool_calls: + collected_tools = blocking_result.tool_calls + + text = "".join(collected).strip() + return text, collected_tools + + def _extract_data( + self, + raw_text: str, + tool_calls: list | None, + extraction_method: ResultExtractionMethod, + chat_log: conversation.ChatLog, + ) -> ai_task.GenDataTaskResult: + """Extract the final data from the LLM response based on the extraction method.""" + if extraction_method == ResultExtractionMethod.NONE: + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=raw_text, + ) + + if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT: + try: + data = json_loads(raw_text) + except JSONDecodeError as err: + raise HomeAssistantError( + "Error with Local LLM structured response" + ) from err + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=data, + ) + + if extraction_method == ResultExtractionMethod.TOOL: + first_tool = (tool_calls or [None])[0] + if not first_tool or not getattr(first_tool, "tool_args", None): + raise HomeAssistantError("Error with Local LLM tool response") + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=first_tool.tool_args, + ) + + raise HomeAssistantError("Invalid extraction method for AI Task") + async def _async_generate_data( self, task: ai_task.GenDataTask, chat_log: conversation.ChatLog, ) -> ai_task.GenDataTaskResult: """Handle a generate data task.""" - - extraction_method = ResultExtractionMethod.NONE - try: - raw_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_PROMPT) + task_prompt = self.runtime_options.get(CONF_AI_TASK_PROMPT, DEFAULT_AI_TASK_PROMPT) + retries = max(0, self.runtime_options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES)) + extraction_method = self.runtime_options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD) + max_attempts = retries + 1 - message_history = chat_log.content[:] + entity_options = {**self.runtime_options} + if task.structure and extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT: + entity_options[CONF_RESPONSE_JSON_SCHEMA] = convert_to_openapi(task.structure) - if not isinstance(message_history[0], conversation.SystemContent): - system_prompt = conversation.SystemContent(content=self.client._generate_system_prompt(raw_prompt, None, self.runtime_options)) - message_history.insert(0, system_prompt) - - _LOGGER.debug(f"Generating response for {task.name=}...") - generation_result = await self.client._async_generate(message_history, self.entity_id, chat_log, self.runtime_options) + message_history = list(chat_log.content) if chat_log.content else [] - assistant_message = await anext(generation_result) - if not isinstance(assistant_message, conversation.AssistantContent): - raise HomeAssistantError("Last content in chat log is not an AssistantContent!") - text = assistant_message.content - - 
if not task.structure: - return ai_task.GenDataTaskResult( - conversation_id=chat_log.conversation_id, - data=text, - ) - - if extraction_method == ResultExtractionMethod.NONE: - raise HomeAssistantError("Task structure provided but no extraction method was specified!") - elif extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT: - try: - data = json_loads(text) - except JSONDecodeError as err: - _LOGGER.error( - "Failed to parse JSON response: %s. Response: %s", - err, - text, - ) - raise HomeAssistantError("Error with Local LLM structured response") from err - elif extraction_method == ResultExtractionMethod.TOOL: - try: - data = assistant_message.tool_calls[0].tool_args - except (IndexError, AttributeError) as err: - _LOGGER.error( - "Failed to extract tool arguments from response: %s. Response: %s", - err, - text, - ) - raise HomeAssistantError("Error with Local LLM tool response") from err + system_message = conversation.SystemContent(content=task_prompt) + if message_history and isinstance(message_history[0], conversation.SystemContent): + message_history[0] = system_message else: - raise ValueError() # should not happen + message_history.insert(0, system_message) - return ai_task.GenDataTaskResult( - conversation_id=chat_log.conversation_id, - data=data, - ) + if not any(isinstance(msg, conversation.UserContent) for msg in message_history): + message_history.append( + conversation.UserContent( + content=task.instructions, attachments=task.attachments + ) + ) + + if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT and not task.structure: + raise HomeAssistantError( + "Structured extraction selected but no task structure was provided" + ) + if extraction_method == ResultExtractionMethod.TOOL: + if not task.structure: + raise HomeAssistantError( + "Tool extraction selected but no task structure was provided" + ) + + parameters_schema = vol.Schema({}, extra=vol.ALLOW_EXTRA) + if isinstance(task.structure, dict): + parameters_schema = vol.Schema(task.structure) + + chat_log.llm_api = await SubmitResponseAPI(self.hass, [SubmitResponseTool(parameters_schema)]).async_get_api_instance( + llm.LLMContext(DOMAIN, context=None, language=None, assistant=None, device_id=None) + ) + + last_error: Exception | None = None + for attempt in range(max_attempts): + try: + _LOGGER.debug( + "Generating response for %s (attempt %s/%s)...", + task.name, + attempt + 1, + max_attempts, + ) + text, tool_calls = await self._generate_once(message_history, chat_log, entity_options) + return self._extract_data(text, tool_calls, extraction_method, chat_log) + except HomeAssistantError as err: + last_error = err + if attempt < max_attempts - 1: + continue + raise + except Exception as err: + last_error = err + _LOGGER.exception( + "Unhandled exception while running AI Task '%s'", + task.name, + ) + raise HomeAssistantError( + f"Unhandled error while running AI Task '{task.name}'" + ) from err + + if last_error: + raise last_error + + raise HomeAssistantError("AI Task generation failed without an error") except Exception as err: _LOGGER.exception("Unhandled exception while running AI Task '%s'", task.name) - raise HomeAssistantError(f"Unhandled error while running AI Task '{task.name}'") from err + raise HomeAssistantError( + f"Unhandled error while running AI Task '{task.name}'" + ) from err diff --git a/custom_components/llama_conversation/backends/generic_openai.py b/custom_components/llama_conversation/backends/generic_openai.py index 6d1631c..3d5a88b 100644 --- 
a/custom_components/llama_conversation/backends/generic_openai.py +++ b/custom_components/llama_conversation/backends/generic_openai.py @@ -28,6 +28,7 @@ from custom_components.llama_conversation.const import ( CONF_REMEMBER_CONVERSATION_TIME_MINUTES, CONF_GENERIC_OPENAI_PATH, CONF_ENABLE_LEGACY_TOOL_CALLING, + CONF_RESPONSE_JSON_SCHEMA, DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P, @@ -120,7 +121,8 @@ class GenericOpenAIAPIClient(LocalLLMClient): conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, - entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]: + entity_options: dict[str, Any], + ) -> AsyncGenerator[TextGenerationResult, None]: model_name = entity_options[CONF_CHAT_MODEL] temperature = entity_options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) top_p = entity_options.get(CONF_TOP_P, DEFAULT_TOP_P) @@ -138,6 +140,17 @@ class GenericOpenAIAPIClient(LocalLLMClient): "messages": messages } + response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA) + if response_json_schema: + request_params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "ha_task", + "schema": response_json_schema, + "strict": True, + }, + } + tools = None # "legacy" tool calling passes the tools directly as part of the system prompt instead of as "tools" # most local backends absolutely butcher any sort of prompt formatting when using tool calling @@ -258,7 +271,6 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient): try: if msg.role == "user": input_text = msg.content - break except Exception: continue @@ -358,7 +370,8 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient): conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, - entity_options: dict[str, Any]) -> TextGenerationResult: + entity_options: dict[str, Any], + ) -> TextGenerationResult: """Generate a response using the OpenAI-compatible Responses API (non-streaming endpoint wrapped as a single-chunk stream).""" model_name = entity_options.get(CONF_CHAT_MODEL) @@ -368,6 +381,16 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient): request_params: Dict[str, Any] = { "model": model_name, } + response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA) + if response_json_schema: + request_params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "ha_task", + "schema": response_json_schema, + "strict": True, + }, + } request_params.update(additional_params) headers: Dict[str, Any] = {} diff --git a/custom_components/llama_conversation/backends/llamacpp.py b/custom_components/llama_conversation/backends/llamacpp.py index 7a494bf..ed0e59b 100644 --- a/custom_components/llama_conversation/backends/llamacpp.py +++ b/custom_components/llama_conversation/backends/llamacpp.py @@ -57,6 +57,7 @@ from custom_components.llama_conversation.const import ( DEFAULT_LLAMACPP_THREAD_COUNT, DEFAULT_LLAMACPP_BATCH_THREAD_COUNT, DOMAIN, + CONF_RESPONSE_JSON_SCHEMA, ) from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult @@ -64,10 +65,15 @@ from custom_components.llama_conversation.entity import LocalLLMClient, TextGene from typing import TYPE_CHECKING from types import ModuleType if TYPE_CHECKING: - from llama_cpp import Llama as LlamaType, LlamaGrammar as LlamaGrammarType + from llama_cpp import ( + Llama as LlamaType, + LlamaGrammar as LlamaGrammarType, + ChatCompletionRequestResponseFormat + ) else: LlamaType = Any LlamaGrammarType = Any + 
ChatCompletionRequestResponseFormat = Any _LOGGER = logging.getLogger(__name__) @@ -441,6 +447,14 @@ class LlamaCppClient(LocalLLMClient): _LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...") + response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA) + response_format: Optional[ChatCompletionRequestResponseFormat] = None + if response_json_schema: + response_format = { + "type": "json_object", + "schema": response_json_schema, + } + chat_completion = self.models[model_name].create_chat_completion( messages, tools=tools, @@ -452,6 +466,7 @@ class LlamaCppClient(LocalLLMClient): max_tokens=max_tokens, grammar=grammar, stream=True, + response_format=response_format, ) def next_token() -> Generator[tuple[Optional[str], Optional[List]]]: diff --git a/custom_components/llama_conversation/backends/ollama.py b/custom_components/llama_conversation/backends/ollama.py index e62d40d..d2c75e7 100644 --- a/custom_components/llama_conversation/backends/ollama.py +++ b/custom_components/llama_conversation/backends/ollama.py @@ -33,6 +33,7 @@ from custom_components.llama_conversation.const import ( CONF_OLLAMA_JSON_MODE, CONF_CONTEXT_LENGTH, CONF_ENABLE_LEGACY_TOOL_CALLING, + CONF_RESPONSE_JSON_SCHEMA, DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, @@ -216,13 +217,14 @@ class OllamaAPIClient(LocalLLMClient): async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]: client = self._build_client(timeout=timeout) try: + format_option = entity_options.get(CONF_RESPONSE_JSON_SCHEMA, "json" if json_mode else None) stream = await client.chat( model=model_name, messages=messages, tools=tools, stream=True, think=think_mode, - format="json" if json_mode else None, + format=format_option, options=options, keep_alive=keep_alive_payload, ) diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py index f4d2ead..5946bf2 100644 --- a/custom_components/llama_conversation/config_flow.py +++ b/custom_components/llama_conversation/config_flow.py @@ -46,6 +46,12 @@ from .const import ( CONF_CHAT_MODEL, CONF_MAX_TOKENS, CONF_PROMPT, + CONF_AI_TASK_PROMPT, + DEFAULT_AI_TASK_PROMPT, + CONF_AI_TASK_RETRIES, + DEFAULT_AI_TASK_RETRIES, + CONF_AI_TASK_EXTRACTION_METHOD, + DEFAULT_AI_TASK_EXTRACTION_METHOD, CONF_TEMPERATURE, CONF_TOP_K, CONF_TOP_P, @@ -150,7 +156,7 @@ from .const import ( DEFAULT_OPTIONS, option_overrides, RECOMMENDED_CHAT_MODELS, - EMBEDDED_LLAMA_CPP_PYTHON_VERSION + EMBEDDED_LLAMA_CPP_PYTHON_VERSION, ) from . 
import HomeLLMAPI, LocalLLMConfigEntry, LocalLLMClient, BACKEND_TO_CLS @@ -400,7 +406,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN): """Return subentries supported by this integration.""" return { "conversation": LocalLLMSubentryFlowHandler, - # "ai_task_data": LocalLLMSubentryFlowHandler, + "ai_task_data": LocalLLMSubentryFlowHandler, } @@ -584,69 +590,98 @@ def local_llama_config_option_schema( subentry_type: str, ) -> dict: - default_prompt = build_prompt_template(language, DEFAULT_PROMPT) + is_ai_task = subentry_type == "ai_task_data" + default_prompt = DEFAULT_AI_TASK_PROMPT if is_ai_task else build_prompt_template(language, DEFAULT_PROMPT) + prompt_key = CONF_AI_TASK_PROMPT if is_ai_task else CONF_PROMPT + prompt_selector = TextSelector(TextSelectorConfig(type=TextSelectorType.TEXT, multiline=True)) if is_ai_task else TemplateSelector() - result: dict = { - vol.Optional( - CONF_PROMPT, - description={"suggested_value": options.get(CONF_PROMPT, default_prompt)}, - default=options.get(CONF_PROMPT, default_prompt), - ): TemplateSelector(), - vol.Optional( - CONF_TEMPERATURE, - description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)}, - default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE), - ): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)), - vol.Required( - CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, - description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)}, - default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES, - ): BooleanSelector(BooleanSelectorConfig()), - vol.Required( - CONF_IN_CONTEXT_EXAMPLES_FILE, - description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)}, - default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE, - ): str, - vol.Required( - CONF_NUM_IN_CONTEXT_EXAMPLES, - description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)}, - default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES, - ): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)), - vol.Required( - CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, - description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)}, - default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, - ): TextSelector(TextSelectorConfig(multiple=True)), - vol.Required( - CONF_THINKING_PREFIX, - description={"suggested_value": options.get(CONF_THINKING_PREFIX)}, - default=DEFAULT_THINKING_PREFIX, - ): str, - vol.Required( - CONF_THINKING_SUFFIX, - description={"suggested_value": options.get(CONF_THINKING_SUFFIX)}, - default=DEFAULT_THINKING_SUFFIX, - ): str, - vol.Required( - CONF_TOOL_CALL_PREFIX, - description={"suggested_value": options.get(CONF_TOOL_CALL_PREFIX)}, - default=DEFAULT_TOOL_CALL_PREFIX, - ): str, - vol.Required( - CONF_TOOL_CALL_SUFFIX, - description={"suggested_value": options.get(CONF_TOOL_CALL_SUFFIX)}, - default=DEFAULT_TOOL_CALL_SUFFIX, - ): str, - vol.Required( - CONF_ENABLE_LEGACY_TOOL_CALLING, - description={"suggested_value": options.get(CONF_ENABLE_LEGACY_TOOL_CALLING)}, - default=DEFAULT_ENABLE_LEGACY_TOOL_CALLING - ): bool, - } + if is_ai_task: + result: dict = { + vol.Optional( + prompt_key, + description={"suggested_value": options.get(prompt_key, default_prompt)}, + default=options.get(prompt_key, default_prompt), + ): prompt_selector, + vol.Required( + CONF_AI_TASK_EXTRACTION_METHOD, + description={"suggested_value": options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)}, + default=options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD), + ): 
SelectSelector(SelectSelectorConfig( + options=[ + SelectOptionDict(value="none", label="None"), + SelectOptionDict(value="structure", label="Structured output"), + SelectOptionDict(value="tool", label="Tool call"), + ], + mode=SelectSelectorMode.DROPDOWN, + )), + vol.Required( + CONF_AI_TASK_RETRIES, + description={"suggested_value": options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES)}, + default=options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES), + ): NumberSelector(NumberSelectorConfig(min=0, max=5, step=1, mode=NumberSelectorMode.BOX)), + } + else: + result: dict = { + vol.Optional( + prompt_key, + description={"suggested_value": options.get(prompt_key, default_prompt)}, + default=options.get(prompt_key, default_prompt), + ): prompt_selector, + vol.Optional( + CONF_TEMPERATURE, + description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)}, + default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE), + ): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)), + vol.Required( + CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, + description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)}, + default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES, + ): BooleanSelector(BooleanSelectorConfig()), + vol.Required( + CONF_IN_CONTEXT_EXAMPLES_FILE, + description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)}, + default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE, + ): str, + vol.Required( + CONF_NUM_IN_CONTEXT_EXAMPLES, + description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)}, + default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES, + ): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)), + vol.Required( + CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, + description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)}, + default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, + ): TextSelector(TextSelectorConfig(multiple=True)), + vol.Required( + CONF_THINKING_PREFIX, + description={"suggested_value": options.get(CONF_THINKING_PREFIX)}, + default=DEFAULT_THINKING_PREFIX, + ): str, + vol.Required( + CONF_THINKING_SUFFIX, + description={"suggested_value": options.get(CONF_THINKING_SUFFIX)}, + default=DEFAULT_THINKING_SUFFIX, + ): str, + vol.Required( + CONF_TOOL_CALL_PREFIX, + description={"suggested_value": options.get(CONF_TOOL_CALL_PREFIX)}, + default=DEFAULT_TOOL_CALL_PREFIX, + ): str, + vol.Required( + CONF_TOOL_CALL_SUFFIX, + description={"suggested_value": options.get(CONF_TOOL_CALL_SUFFIX)}, + default=DEFAULT_TOOL_CALL_SUFFIX, + ): str, + vol.Required( + CONF_ENABLE_LEGACY_TOOL_CALLING, + description={"suggested_value": options.get(CONF_ENABLE_LEGACY_TOOL_CALLING)}, + default=DEFAULT_ENABLE_LEGACY_TOOL_CALLING + ): bool, + } if backend_type == BACKEND_TYPE_LLAMA_CPP: result.update({ - vol.Required( + vol.Required( CONF_MAX_TOKENS, description={"suggested_value": options.get(CONF_MAX_TOKENS)}, default=DEFAULT_MAX_TOKENS, @@ -920,13 +955,17 @@ def local_llama_config_option_schema( ): int, }) elif subentry_type == "ai_task_data": - pass # no additional options for ai_task_data for now + # no extra conversation/tool options for ai_task_data beyond schema defaults + pass # sort the options global_order = [ # general CONF_LLM_HASS_API, CONF_PROMPT, + CONF_AI_TASK_PROMPT, + CONF_AI_TASK_EXTRACTION_METHOD, + CONF_AI_TASK_RETRIES, CONF_CONTEXT_LENGTH, CONF_MAX_TOKENS, # sampling parameters @@ -1116,8 +1155,16 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow): description_placeholders = {} entry = 
self._get_entry() backend_type = entry.data[CONF_BACKEND_TYPE] + is_ai_task = self._subentry_type == "ai_task_data" - if CONF_PROMPT not in self.model_config: + if is_ai_task: + if CONF_AI_TASK_PROMPT not in self.model_config: + self.model_config[CONF_AI_TASK_PROMPT] = DEFAULT_AI_TASK_PROMPT + if CONF_AI_TASK_RETRIES not in self.model_config: + self.model_config[CONF_AI_TASK_RETRIES] = DEFAULT_AI_TASK_RETRIES + if CONF_AI_TASK_EXTRACTION_METHOD not in self.model_config: + self.model_config[CONF_AI_TASK_EXTRACTION_METHOD] = DEFAULT_AI_TASK_EXTRACTION_METHOD + elif CONF_PROMPT not in self.model_config: # determine selected language from model config or parent options selected_language = self.model_config.get( CONF_SELECTED_LANGUAGE, entry.options.get(CONF_SELECTED_LANGUAGE, "en") @@ -1150,20 +1197,21 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow): ) if user_input: - if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED): - errors["base"] = "sys_refresh_caching_enabled" + if not is_ai_task: + if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED): + errors["base"] = "sys_refresh_caching_enabled" - if user_input.get(CONF_USE_GBNF_GRAMMAR): - filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE) - if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)): - errors["base"] = "missing_gbnf_file" - description_placeholders["filename"] = filename + if user_input.get(CONF_USE_GBNF_GRAMMAR): + filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE) + if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)): + errors["base"] = "missing_gbnf_file" + description_placeholders["filename"] = filename - if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES): - filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE) - if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)): - errors["base"] = "missing_icl_file" - description_placeholders["filename"] = filename + if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES): + filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE) + if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)): + errors["base"] = "missing_icl_file" + description_placeholders["filename"] = filename # --- Normalize numeric fields to ints to avoid slice/type errors later --- for key in ( @@ -1172,6 +1220,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow): CONF_CONTEXT_LENGTH, CONF_MAX_TOKENS, CONF_REQUEST_TIMEOUT, + CONF_AI_TASK_RETRIES, ): if key in user_input: user_input[key] = _coerce_int(user_input[key], user_input.get(key) or 0) diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index d40f06c..cce65d4 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -8,6 +8,12 @@ SERVICE_TOOL_NAME = "HassCallService" SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"] SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"] CONF_PROMPT = 
"prompt" +CONF_AI_TASK_PROMPT = "ai_task_prompt" +DEFAULT_AI_TASK_PROMPT = "You are a task-specific assistant. Follow the task instructions and return the requested data." +CONF_AI_TASK_RETRIES = "ai_task_retries" +DEFAULT_AI_TASK_RETRIES = 0 +CONF_AI_TASK_EXTRACTION_METHOD = "ai_task_extraction_method" +DEFAULT_AI_TASK_EXTRACTION_METHOD = "structure" PERSONA_PROMPTS = { "en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.", "de": "Du bist \u201eAl\u201c, ein hilfreicher KI-Assistent, der die Ger\u00e4te in einem Haus steuert. F\u00fchren Sie die folgende Aufgabe gem\u00e4\u00df den Anweisungen durch oder beantworten Sie die folgende Frage nur mit den bereitgestellten Informationen.", @@ -188,6 +194,7 @@ CONF_GENERIC_OPENAI_VALIDATE_MODEL = "openai_validate_model" DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL = True CONF_CONTEXT_LENGTH = "context_length" DEFAULT_CONTEXT_LENGTH = 8192 +CONF_RESPONSE_JSON_SCHEMA = "response_json_schema" CONF_LLAMACPP_BATCH_SIZE = "batch_size" DEFAULT_LLAMACPP_BATCH_SIZE = 512 CONF_LLAMACPP_THREAD_COUNT = "n_threads" diff --git a/custom_components/llama_conversation/utils.py b/custom_components/llama_conversation/utils.py index fd124eb..4fe9b95 100644 --- a/custom_components/llama_conversation/utils.py +++ b/custom_components/llama_conversation/utils.py @@ -24,7 +24,7 @@ from homeassistant.requirements import pip_kwargs from homeassistant.util import color from homeassistant.util.package import install_package, is_installed -from voluptuous_openapi import convert +from voluptuous_openapi import convert as convert_to_openapi from .const import ( DOMAIN, @@ -285,7 +285,7 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis "function": { "name": tool["name"], "description": f"Call the Home Assistant service '{tool['name']}'", - "parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer) + "parameters": convert_to_openapi(tool["arguments"], custom_serializer=llm_api.custom_serializer) } } for tool in get_home_llm_tools(llm_api, domains) ]) else: @@ -294,7 +294,7 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis "function": { "name": tool.name, "description": tool.description or "", - "parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer) + "parameters": convert_to_openapi(tool.parameters, custom_serializer=llm_api.custom_serializer) } }) diff --git a/tests/llama_conversation/test_ai_task.py b/tests/llama_conversation/test_ai_task.py new file mode 100644 index 0000000..09ff8ff --- /dev/null +++ b/tests/llama_conversation/test_ai_task.py @@ -0,0 +1,153 @@ +"""Tests for AI Task extraction behavior.""" + +from typing import Any, cast + +import pytest + +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import llm + +from custom_components.llama_conversation.ai_task import ( + LocalLLMTaskEntity, + ResultExtractionMethod, +) +from custom_components.llama_conversation.const import ( + CONF_AI_TASK_EXTRACTION_METHOD, +) +from custom_components.llama_conversation.entity import TextGenerationResult + + +class DummyGenTask: + def __init__(self, *, name="task", instructions="do", attachments=None, structure=None): + self.name = name + self.instructions = instructions + self.attachments = attachments or [] + self.structure = structure + + +class DummyChatLog: + def __init__(self, content=None): + self.content = content or [] + 
self.conversation_id = "conv-id" + self.llm_api = None + +class DummyClient: + def __init__(self, result: TextGenerationResult): + self._result = result + + def _supports_vision(self, _options): # pragma: no cover - not needed for tests + return False + + async def _generate(self, _messages, _llm_api, _entity_id, _options): + return self._result + + +class DummyTaskEntity(LocalLLMTaskEntity): + def __init__(self, hass, client, options): + # Bypass parent init to avoid ConfigEntry/Subentry plumbing. + self.hass = hass + self.client = client + self._runtime_options = options + self.entity_id = "ai_task.test" + self.entry_id = "entry" + self.subentry_id = "subentry" + self._attr_supported_features = 0 + + @property + def runtime_options(self): + return self._runtime_options + + +@pytest.mark.asyncio +async def test_no_extraction_returns_raw_text(hass): + entity = DummyTaskEntity( + hass, + DummyClient(TextGenerationResult(response="raw text")), + {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.NONE}, + ) + chat_log = DummyChatLog() + task = DummyGenTask() + + result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log)) + + assert result.data == "raw text" + assert chat_log.llm_api is None + + +@pytest.mark.asyncio +async def test_structured_output_success(hass): + entity = DummyTaskEntity( + hass, + DummyClient(TextGenerationResult(response='{"foo": 1}')), + {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT}, + ) + chat_log = DummyChatLog() + task = DummyGenTask(structure={"foo": int}) + + result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log)) + + assert result.data == {"foo": 1} + + +@pytest.mark.asyncio +async def test_structured_output_invalid_json_raises(hass): + entity = DummyTaskEntity( + hass, + DummyClient(TextGenerationResult(response="not-json")), + {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT}, + ) + chat_log = DummyChatLog() + task = DummyGenTask(structure={"foo": int}) + + with pytest.raises(HomeAssistantError): + await entity._async_generate_data(cast(Any, task), cast(Any, chat_log)) + + +@pytest.mark.asyncio +async def test_tool_extraction_success(hass): + tool_call = llm.ToolInput("submit_response", {"value": 42}) + entity = DummyTaskEntity( + hass, + DummyClient(TextGenerationResult(response="", tool_calls=[tool_call])), + {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL}, + ) + chat_log = DummyChatLog() + task = DummyGenTask(structure={"value": int}) + + result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log)) + + assert result.data == {"value": 42} + assert chat_log.llm_api is not None + + +@pytest.mark.asyncio +async def test_tool_extraction_missing_tool_args_raises(hass): + class DummyToolCall: + def __init__(self, tool_args=None): + self.tool_args = tool_args + + tool_call = DummyToolCall(tool_args=None) + entity = DummyTaskEntity( + hass, + DummyClient(TextGenerationResult(response="", tool_calls=cast(Any, [tool_call]))), + {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL}, + ) + chat_log = DummyChatLog() + task = DummyGenTask(structure={"value": int}) + + with pytest.raises(HomeAssistantError): + await entity._async_generate_data(cast(Any, task), cast(Any, chat_log)) + + +@pytest.mark.asyncio +async def test_tool_extraction_requires_structure(hass): + entity = DummyTaskEntity( + hass, + DummyClient(TextGenerationResult(response="")), + {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL}, + ) + chat_log = 
DummyChatLog() + task = DummyGenTask(structure=None) + + with pytest.raises(HomeAssistantError): + await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
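
The tests above cover the three extraction methods but not the new CONF_AI_TASK_RETRIES option. A follow-up test could exercise the retry loop in _async_generate_data, which treats the HomeAssistantError raised during extraction as retryable. The sketch below is illustrative only and is not part of this diff: it assumes it lives in the same test module (reusing the DummyTaskEntity, DummyChatLog, DummyGenTask helpers and the existing imports), and the FlakyClient stub plus the chosen option values are hypothetical.

from custom_components.llama_conversation.const import CONF_AI_TASK_RETRIES


class FlakyClient:
    """Test double that returns a bad payload first, then a good one."""

    def __init__(self, results):
        self._results = list(results)

    def _supports_vision(self, _options):
        return False

    async def _generate(self, _messages, _llm_api, _entity_id, _options):
        # Pop one canned result per generation attempt so the first parse fails.
        return self._results.pop(0)


@pytest.mark.asyncio
async def test_structured_output_retries_then_succeeds(hass):
    entity = DummyTaskEntity(
        hass,
        FlakyClient(
            [
                TextGenerationResult(response="not-json"),
                TextGenerationResult(response='{"value": 1}'),
            ]
        ),
        {
            CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT,
            CONF_AI_TASK_RETRIES: 1,
        },
    )
    chat_log = DummyChatLog()
    task = DummyGenTask(structure={"value": int})

    result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))

    assert result.data == {"value": 1}

With retries set to 1, the first attempt's invalid JSON raises inside _extract_data, the loop continues, and the second attempt returns the parsed structure; with retries left at the default of 0 the same input would surface the HomeAssistantError directly.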