Merge pull request #326 from acon96/release/v0.4.5

Release v0.4.5
Alex O'Connell
2025-12-14 20:23:27 -05:00
committed by GitHub
27 changed files with 1640 additions and 1495 deletions

.gitignore vendored

@@ -3,12 +3,13 @@ loras/
 core/
 config/
 .DS_Store
-data/*.json
-data/*.jsonl
+data/**/*.json
+data/**/*.jsonl
 *.pyc
 main.log
 .venv
 *.xlsx
 notes.txt
 runpod_bootstrap.sh
 *.code-workspace
+.coverage


@@ -4,16 +4,17 @@ This project provides the required "glue" components to control your Home Assist
 ## Quick Start
 Please see the [Setup Guide](./docs/Setup.md) for more information on installation.

-## Local LLM Conversation Integration
+## Local LLM Integration
 **The latest version of this integration requires Home Assistant 2025.7.0 or newer**
-In order to integrate with Home Assistant, we provide a custom component that exposes the locally running LLM as a "conversation agent".
+In order to integrate with Home Assistant, we provide a custom component that exposes the locally running LLM as a "conversation agent" or as an "AI Task handler".
 This component can be interacted with in a few ways:
 - using a chat interface so you can chat with it.
 - integrating with Speech-to-Text and Text-to-Speech addons so you can just speak to it.
+- using automations or scripts to trigger "AI Tasks"; these process input data with a prompt and return structured data that can be used in further automations.

-The integration can either run the model in 2 different ways:
+The integration can run the model in a few different ways:
 1. Directly as part of the Home Assistant software using llama-cpp-python
 2. On a separate machine using one of the following backends:
    - [Ollama](https://ollama.com/) (easier)
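As a point of reference for the AI Task wiring introduced in this release, the sketch below shows roughly how the structured-data path could be exercised from Python. It is not code from this PR: the `ai_task.generate_data` service call, the entity id, and the `structure` fields are assumptions based on Home Assistant's standard AI Task service and should be checked against your HA version.

```python
# Hypothetical example: trigger an AI Task against this integration's entity and
# read back the structured result. Service and field names are assumptions, not from this PR.
from homeassistant.core import HomeAssistant


async def async_summarize_sensors(hass: HomeAssistant) -> dict | None:
    response = await hass.services.async_call(
        "ai_task",
        "generate_data",
        {
            "task_name": "sensor_summary",        # free-form label for the task
            "entity_id": "ai_task.local_llm",     # hypothetical AI Task entity created by this integration
            "instructions": "Summarize today's temperature and humidity readings.",
            "structure": {                        # selector-style schema for the structured reply
                "summary": {"selector": {"text": None}},
                "alert": {"selector": {"boolean": None}},
            },
        },
        blocking=True,
        return_response=True,
    )
    # The service response carries the data the LLM produced (validated/extracted
    # by the entity according to its configured extraction method).
    return response
```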
@@ -36,6 +37,7 @@ The latest models can be found on HuggingFace:
 **Gemma3**:
 1B: TBD
+270M: TBD

 <details>
@@ -158,6 +160,7 @@ python3 train.py \
 ## Version History
 | Version | Description |
 |---------|-------------|
+| v0.4.5 | Add support for AI Task entities; replace the custom Ollama API implementation with the official `ollama-python` package to avoid future compatibility issues; support multiple LLM APIs at once; fix tool call handling issues for various backends |
 | v0.4.4 | Fix issue with OpenAI backends appending `/v1` to all URLs, and fix an issue with tools being serialized into the system prompt. |
 | v0.4.3 | Fix an issue with the integration not creating model configs properly during setup |
 | v0.4.2 | Fix the following issues: not correctly setting default model settings during initial setup, non-integers being allowed in numeric config fields, being too strict with finish_reason requirements, and not letting the user clear the active LLM API |


@@ -57,7 +57,7 @@ _LOGGER = logging.getLogger(__name__)
 CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
-PLATFORMS = (Platform.CONVERSATION, ) # Platform.AI_TASK)
+PLATFORMS = (Platform.CONVERSATION, Platform.AI_TASK)
 BACKEND_TO_CLS: dict[str, type[LocalLLMClient]] = {
     BACKEND_TYPE_LLAMA_CPP: LlamaCppClient,
@@ -184,6 +184,21 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigE
         _LOGGER.debug("Migration to add downloaded model file complete")

+    if config_entry.version == 3 and config_entry.minor_version == 1:
+        # convert selected APIs from single value to list
+        api_to_convert = config_entry.options.get(CONF_LLM_HASS_API)
+        new_options = dict(config_entry.options)
+        if api_to_convert is not None:
+            new_options[CONF_LLM_HASS_API] = [api_to_convert]
+        else:
+            new_options[CONF_LLM_HASS_API] = []
+        hass.config_entries.async_update_entry(
+            config_entry, options=MappingProxyType(new_options)
+        )
+        hass.config_entries.async_update_entry(config_entry, minor_version=2)
+
     return True

 class HassServiceTool(llm.Tool):
@@ -255,14 +270,14 @@ class HomeLLMAPI(llm.API):
         super().__init__(
             hass=hass,
             id=HOME_LLM_API_ID,
-            name="Home-LLM (v1-v3)",
+            name="Home Assistant Services",
         )

     async def async_get_api_instance(self, llm_context: llm.LLMContext) -> llm.APIInstance:
         """Return the instance of the API."""
         return llm.APIInstance(
             api=self,
-            api_prompt="Call services in Home Assistant by passing the service name and the device to control.",
+            api_prompt="Call services in Home Assistant by passing the service name and the device to control. Designed for Home-LLM Models (v1-v3)",
             llm_context=llm_context,
             tools=[HassServiceTool()],
         )
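For orientation, `HomeLLMAPI` is an `llm.API` that supplies the model with a single `HassServiceTool`. The snippet below is a made-up tool of the same general shape (not part of this diff) to show what such an API hands to the LLM: a name, a description, a voluptuous `parameters` schema, and an `async_call` that runs when the model invokes it.

```python
# Hypothetical tool following the same pattern as HassServiceTool above; not from this PR.
import voluptuous as vol

from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm


class AnnounceTool(llm.Tool):
    """Example tool the LLM could call to broadcast a short message."""

    name = "announce"
    description = "Broadcast a short announcement on the house speakers"
    parameters = vol.Schema({vol.Required("message"): str})

    async def async_call(
        self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
    ) -> dict:
        # tool_args has already been validated against `parameters` by the LLM helper
        await hass.services.async_call(
            "notify", "notify", {"message": tool_input.tool_args["message"]}
        )
        return {"success": True}
```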


@@ -1,14 +1,18 @@
 """AI Task integration for Local LLMs."""
 from __future__ import annotations

 from json import JSONDecodeError
 import logging
 from enum import StrEnum
+from typing import Any

+import voluptuous as vol
+from voluptuous_openapi import convert as convert_to_openapi
+from homeassistant.helpers import llm
 from homeassistant.components import ai_task, conversation
 from homeassistant.config_entries import ConfigEntry
-from homeassistant.core import HomeAssistant, Context
+from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
 from homeassistant.util.json import json_loads
@@ -16,7 +20,13 @@ from homeassistant.util.json import json_loads
 from .entity import LocalLLMEntity, LocalLLMClient
 from .const import (
     CONF_PROMPT,
-    DEFAULT_PROMPT,
+    CONF_RESPONSE_JSON_SCHEMA,
+    DEFAULT_AI_TASK_PROMPT,
+    CONF_AI_TASK_RETRIES,
+    DEFAULT_AI_TASK_RETRIES,
+    CONF_AI_TASK_EXTRACTION_METHOD,
+    DEFAULT_AI_TASK_EXTRACTION_METHOD,
+    DOMAIN,
 )

 _LOGGER = logging.getLogger(__name__)
@@ -29,98 +39,222 @@ async def async_setup_entry(
 ) -> None:
     """Set up AI Task entities."""
     for subentry in config_entry.subentries.values():
-        if subentry.subentry_type != "ai_task_data":
+        if subentry.subentry_type != ai_task.DOMAIN:
             continue

-        async_add_entities(
-            [LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)],
-            config_subentry_id=subentry.subentry_id,
-        )
+        # create one entity per subentry
+        ai_task_entity = LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)
+
+        # make sure model is loaded
+        await config_entry.runtime_data._async_load_model(dict(subentry.data))
+
+        # register the ai task entity
+        async_add_entities([ai_task_entity], config_subentry_id=subentry.subentry_id)

 class ResultExtractionMethod(StrEnum):
     NONE = "none"
     STRUCTURED_OUTPUT = "structure"
     TOOL = "tool"
class SubmitResponseTool(llm.Tool):
name = "submit_response"
description = "Submit the structured response payload for the AI task"
def __init__(self, parameters_schema: vol.Schema):
self.parameters = parameters_schema
async def async_call(
self,
hass: HomeAssistant,
tool_input: llm.ToolInput,
llm_context: llm.LLMContext,
) -> dict:
return tool_input.tool_args or {}
class SubmitResponseAPI(llm.API):
def __init__(self, hass: HomeAssistant, tools: list[llm.Tool]) -> None:
self._tools = tools
super().__init__(
hass=hass,
id=f"{DOMAIN}-ai-task-tool",
name="AI Task Tool API",
)
async def async_get_api_instance(
self, llm_context: llm.LLMContext
) -> llm.APIInstance:
return llm.APIInstance(
api=self,
api_prompt="Call submit_response to return the structured AI task result.",
llm_context=llm_context,
tools=self._tools,
custom_serializer=llm.selector_serializer,
)
class LocalLLMTaskEntity( class LocalLLMTaskEntity(
ai_task.AITaskEntity, ai_task.AITaskEntity,
LocalLLMEntity, LocalLLMEntity,
): ):
"""Ollama AI Task entity.""" """AI Task entity."""
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
"""Initialize Ollama AI Task entity.""" """Initialize AI Task entity."""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if self.client._supports_vision(self.runtime_options): if self.client._supports_vision(self.runtime_options):
self._attr_supported_features = ( self._attr_supported_features = (
ai_task.AITaskEntityFeature.GENERATE_DATA | ai_task.AITaskEntityFeature.GENERATE_DATA
ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
) )
else: else:
self._attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA self._attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
async def _generate_once(
self,
message_history: list[conversation.Content],
llm_api: llm.APIInstance | None,
entity_options: dict[str, Any],
) -> tuple[str, list | None, Exception | None]:
"""Generate a single response from the LLM."""
collected_tools = None
text = ""
# call the LLM client directly (not _async_generate) since that will attempt to execute tool calls
try:
if hasattr(self.client, "_generate_stream"):
async for chunk in self.client._generate_stream(
message_history,
llm_api,
self.entity_id,
entity_options,
):
if chunk.response:
text += chunk.response.strip()
if chunk.tool_calls:
collected_tools = chunk.tool_calls
else:
blocking_result = await self.client._generate(
message_history,
llm_api,
self.entity_id,
entity_options,
)
if blocking_result.response:
text = blocking_result.response.strip()
if blocking_result.tool_calls:
collected_tools = blocking_result.tool_calls
_LOGGER.debug("AI Task '%s' generated text: %s (tools=%s)", self.entity_id, text, collected_tools)
return text, collected_tools, None
except JSONDecodeError as err:
_LOGGER.debug("AI Task '%s' json error generated text: %s (tools=%s)", self.entity_id, text, collected_tools)
return text, collected_tools, err
def _extract_data(
self,
raw_text: str,
tool_calls: list[llm.ToolInput] | None,
extraction_method: ResultExtractionMethod,
chat_log: conversation.ChatLog,
structure: vol.Schema | None,
) -> tuple[ai_task.GenDataTaskResult | None, Exception | None]:
"""Extract the final data from the LLM response based on the extraction method."""
try:
if extraction_method == ResultExtractionMethod.NONE or structure is None:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=raw_text,
), None
if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
data = json_loads(raw_text)
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=data,
), None
if extraction_method == ResultExtractionMethod.TOOL:
first_tool = next(iter(tool_calls or []), None)
if not first_tool:
return None, HomeAssistantError("Please produce at least one tool call with the structured response.")
structure(first_tool.tool_args) # validate tool call against vol schema structure
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=first_tool.tool_args,
), None
except vol.Invalid as err:
if isinstance(err, vol.MultipleInvalid):
# combine all error messages into one
error_message = "; ".join(f"Error at '{e.path}': {e.error_message}" for e in err.errors)
else:
error_message = f"Error at '{err.path}': {err.error_message}"
return None, HomeAssistantError(f"Please address the following schema errors: {error_message}")
except JSONDecodeError as err:
return None, HomeAssistantError(f"Please produce properly formatted JSON: {repr(err)}")
raise HomeAssistantError(f"Invalid extraction method for AI Task {extraction_method}")
async def _async_generate_data( async def _async_generate_data(
self, self,
task: ai_task.GenDataTask, task: ai_task.GenDataTask,
chat_log: conversation.ChatLog, chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult: ) -> ai_task.GenDataTaskResult:
"""Handle a generate data task.""" """Handle a generate data task."""
raw_task_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT)
retries = max(0, self.runtime_options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES))
extraction_method = self.runtime_options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)
max_attempts = retries + 1
extraction_method = ResultExtractionMethod.NONE entity_options = {**self.runtime_options}
if task.structure: # set up extraction method specifics
try: if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
raw_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_PROMPT) _LOGGER.debug("Using structure for AI Task '%s': %s", task.name, task.structure)
entity_options[CONF_RESPONSE_JSON_SCHEMA] = convert_to_openapi(task.structure, custom_serializer=llm.selector_serializer)
message_history = chat_log.content[:]
if not isinstance(message_history[0], conversation.SystemContent):
system_prompt = conversation.SystemContent(content=self.client._generate_system_prompt(raw_prompt, None, self.runtime_options))
message_history.insert(0, system_prompt)
_LOGGER.debug(f"Generating response for {task.name=}...")
generation_result = await self.client._async_generate(message_history, self.entity_id, chat_log, self.runtime_options)
assistant_message = await anext(generation_result)
if not isinstance(assistant_message, conversation.AssistantContent):
raise HomeAssistantError("Last content in chat log is not an AssistantContent!")
text = assistant_message.content
if not task.structure:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=text,
)
if extraction_method == ResultExtractionMethod.NONE:
raise HomeAssistantError("Task structure provided but no extraction method was specified!")
elif extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
try:
data = json_loads(text)
except JSONDecodeError as err:
_LOGGER.error(
"Failed to parse JSON response: %s. Response: %s",
err,
text,
)
raise HomeAssistantError("Error with Local LLM structured response") from err
elif extraction_method == ResultExtractionMethod.TOOL: elif extraction_method == ResultExtractionMethod.TOOL:
try: chat_log.llm_api = await SubmitResponseAPI(self.hass, [SubmitResponseTool(task.structure)]).async_get_api_instance(
data = assistant_message.tool_calls[0].tool_args llm.LLMContext(DOMAIN, context=None, language=None, assistant=None, device_id=None)
except (IndexError, AttributeError) as err: )
_LOGGER.error(
"Failed to extract tool arguments from response: %s. Response: %s", message_history = list(chat_log.content) if chat_log.content else []
err, task_prompt = self.client._generate_system_prompt(raw_task_prompt, llm_api=chat_log.llm_api, entity_options=entity_options)
text, system_message = conversation.SystemContent(content=task_prompt)
) if message_history and isinstance(message_history[0], conversation.SystemContent):
raise HomeAssistantError("Error with Local LLM tool response") from err message_history[0] = system_message
else: else:
raise ValueError() # should not happen message_history.insert(0, system_message)
return ai_task.GenDataTaskResult( if not any(isinstance(msg, conversation.UserContent) for msg in message_history):
conversation_id=chat_log.conversation_id, message_history.append(
data=data, conversation.UserContent(
content=task.instructions, attachments=task.attachments
)
) )
try:
last_error: Exception | None = None
for attempt in range(max_attempts):
_LOGGER.debug("Generating response for %s (attempt %s/%s)...", task.name, attempt + 1, max_attempts)
text, tool_calls, err = await self._generate_once(message_history, chat_log.llm_api, entity_options)
if err:
last_error = err
message_history.append(conversation.AssistantContent(agent_id=self.entity_id, content=text, tool_calls=tool_calls))
message_history.append(conversation.UserContent(content=f"Error: {str(err)}. Please try again."))
continue
data, err = self._extract_data(text, tool_calls, extraction_method, chat_log, task.structure)
if err:
last_error = err
message_history.append(conversation.AssistantContent(agent_id=self.entity_id, content=text, tool_calls=tool_calls))
message_history.append(conversation.UserContent(content=f"Error: {str(err)}. Please try again."))
continue
if data:
return data
except Exception as err: except Exception as err:
_LOGGER.exception("Unhandled exception while running AI Task '%s'", task.name) _LOGGER.exception("Unhandled exception while running AI Task '%s'", task.name)
raise HomeAssistantError(f"Unhandled error while running AI Task '{task.name}'") from err raise HomeAssistantError(f"Unhandled error while running AI Task '{task.name}'") from err
raise last_error or HomeAssistantError(f"AI Task '{task.name}' failed after {max_attempts} attempts")
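To make the retry behaviour of `_extract_data` above concrete, here is a small, self-contained illustration of how voluptuous reports schema violations; the combined path/message string is what the entity feeds back to the model as the "Please address the following schema errors" prompt. The schema and data are invented for the example.

```python
# Standalone illustration of the vol.MultipleInvalid handling used above; schema and data are made up.
import voluptuous as vol

schema = vol.Schema({
    vol.Required("title"): str,
    vol.Required("rating"): vol.All(int, vol.Range(min=1, max=5)),
})

try:
    schema({"rating": 99})  # missing "title", rating out of range
except vol.MultipleInvalid as err:
    message = "; ".join(f"Error at '{e.path}': {e.error_message}" for e in err.errors)
    print(message)
    # e.g. Error at '['title']': required key not provided; Error at '['rating']': value must be at most 5
```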


@@ -28,6 +28,7 @@ from custom_components.llama_conversation.const import (
     CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
     CONF_GENERIC_OPENAI_PATH,
     CONF_ENABLE_LEGACY_TOOL_CALLING,
+    CONF_RESPONSE_JSON_SCHEMA,
     DEFAULT_MAX_TOKENS,
     DEFAULT_TEMPERATURE,
     DEFAULT_TOP_P,
@@ -110,7 +111,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
             ) as response:
                 response.raise_for_status()
                 models_result = await response.json()
-        except:
+        except (asyncio.TimeoutError, aiohttp.ClientResponseError):
             _LOGGER.exception("Failed to get available models")
             return RECOMMENDED_CHAT_MODELS
@@ -120,7 +121,8 @@ class GenericOpenAIAPIClient(LocalLLMClient):
         conversation: List[conversation.Content],
         llm_api: llm.APIInstance | None,
         agent_id: str,
-        entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
+        entity_options: dict[str, Any],
+    ) -> AsyncGenerator[TextGenerationResult, None]:
         model_name = entity_options[CONF_CHAT_MODEL]
         temperature = entity_options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
         top_p = entity_options.get(CONF_TOP_P, DEFAULT_TOP_P)
@@ -138,6 +140,17 @@ class GenericOpenAIAPIClient(LocalLLMClient):
             "messages": messages
         }

+        response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
+        if response_json_schema:
+            request_params["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "ha_task",
+                    "schema": response_json_schema,
+                    "strict": True,
+                },
+            }
+
         tools = None
         # "legacy" tool calling passes the tools directly as part of the system prompt instead of as "tools"
         # most local backends absolutely butcher any sort of prompt formatting when using tool calling
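The hunk above only forwards a pre-built JSON schema; the conversion from the AI Task's voluptuous structure happens in `ai_task.py` via `voluptuous_openapi`. The sketch below shows roughly what the resulting request body looks like, with a made-up model name and structure.

```python
# Rough sketch of the request body once CONF_RESPONSE_JSON_SCHEMA is set; values are illustrative.
import json

import voluptuous as vol
from voluptuous_openapi import convert

structure = vol.Schema({vol.Required("summary"): str, vol.Optional("confidence"): float})
response_json_schema = convert(structure)  # voluptuous schema -> JSON-schema-style dict

request_params = {
    "model": "qwen2.5:7b",  # hypothetical model name
    "messages": [{"role": "user", "content": "Summarize the sensor readings."}],
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "ha_task", "schema": response_json_schema, "strict": True},
    },
}
print(json.dumps(request_params, indent=2))
```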
@@ -155,7 +168,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
         session = async_get_clientsession(self.hass)

-        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
+        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
             response = None
             chunk = None
             try:
@@ -175,7 +188,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
                             break

                         if chunk and chunk.strip():
-                            to_say, tool_calls = self._extract_response(json.loads(chunk), llm_api, agent_id)
+                            to_say, tool_calls = self._extract_response(json.loads(chunk))
                             if to_say or tool_calls:
                                 yield to_say, tool_calls
             except asyncio.TimeoutError as err:
@@ -183,14 +196,14 @@ class GenericOpenAIAPIClient(LocalLLMClient):
             except aiohttp.ClientError as err:
                 raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err

-        return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
+        return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())

     def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
         request_params = {}
         endpoint = "/chat/completions"
         return endpoint, request_params

-    def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, agent_id: str) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
+    def _extract_response(self, response_json: dict) -> Tuple[Optional[str], Optional[List[dict]]]:
         if "choices" not in response_json or len(response_json["choices"]) == 0:  # finished
             _LOGGER.warning("Response missing or empty 'choices'. Keys present: %s. Full response: %s",
                             list(response_json.keys()), response_json)
@@ -204,16 +217,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
         elif response_json["object"] == "chat.completion.chunk":
             response_text = choice["delta"].get("content", "")
             if "tool_calls" in choice["delta"] and choice["delta"]["tool_calls"] is not None:
-                tool_calls = []
-                for call in choice["delta"]["tool_calls"]:
-                    tool_call, to_say = parse_raw_tool_call(
-                        call["function"], llm_api, agent_id)
-                    if tool_call:
-                        tool_calls.append(tool_call)
-                    if to_say:
-                        response_text += to_say
+                tool_calls = [call["function"] for call in choice["delta"]["tool_calls"]]
             streamed = True
         else:
             response_text = choice["text"]
@@ -267,7 +271,6 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
             try:
                 if msg.role == "user":
                     input_text = msg.content
-                    break
             except Exception:
                 continue
@@ -367,7 +370,8 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
         conversation: List[conversation.Content],
         llm_api: llm.APIInstance | None,
         agent_id: str,
-        entity_options: dict[str, Any]) -> TextGenerationResult:
+        entity_options: dict[str, Any],
+    ) -> TextGenerationResult:
         """Generate a response using the OpenAI-compatible Responses API (non-streaming endpoint wrapped as a single-chunk stream)."""

         model_name = entity_options.get(CONF_CHAT_MODEL)
@@ -377,6 +381,16 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
         request_params: Dict[str, Any] = {
             "model": model_name,
         }
+        response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
+        if response_json_schema:
+            request_params["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "ha_task",
+                    "schema": response_json_schema,
+                    "strict": True,
+                },
+            }
         request_params.update(additional_params)

         headers: Dict[str, Any] = {}
@@ -398,7 +412,10 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
         try:
             text = self._extract_response(response_json)
-            return TextGenerationResult(response=text, response_streamed=False)
+            if not text:
+                return TextGenerationResult(raise_error=True, error_msg="The Responses API returned an empty response.")
+            # return await self._async_parse_completion(llm_api, agent_id, entity_options, text)
+            return TextGenerationResult(response=text)  # Currently we don't extract any info from the response besides the raw model output
         except Exception as err:
             _LOGGER.exception("Failed to parse Responses API payload: %s", err)
             return TextGenerationResult(raise_error=True, error_msg=f"Failed to parse Responses API payload: {err}")


@@ -9,7 +9,6 @@ import time
 from typing import Any, Callable, List, Generator, AsyncGenerator, Optional, cast

 from homeassistant.components import conversation as conversation
-from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
 from homeassistant.components.homeassistant.exposed_entities import async_should_expose
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_LLM_HASS_API
@@ -57,6 +56,7 @@ from custom_components.llama_conversation.const import (
     DEFAULT_LLAMACPP_THREAD_COUNT,
     DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
     DOMAIN,
+    CONF_RESPONSE_JSON_SCHEMA,
 )

 from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult
@@ -64,10 +64,15 @@ from custom_components.llama_conversation.entity import LocalLLMClient, TextGene
 from typing import TYPE_CHECKING
 from types import ModuleType
 if TYPE_CHECKING:
-    from llama_cpp import Llama as LlamaType, LlamaGrammar as LlamaGrammarType
+    from llama_cpp import (
+        Llama as LlamaType,
+        LlamaGrammar as LlamaGrammarType,
+        ChatCompletionRequestResponseFormat
+    )
 else:
     LlamaType = Any
     LlamaGrammarType = Any
+    ChatCompletionRequestResponseFormat = Any

 _LOGGER = logging.getLogger(__name__)
@@ -283,8 +288,6 @@ class LlamaCppClient(LocalLLMClient):
         # Sort the items based on the sort_key function
         sorted_items = sorted(list(entity_order.items()), key=sort_key)

-        _LOGGER.debug(f"sorted_items: {sorted_items}")
-
         sorted_entities: dict[str, dict[str, str]] = {}
         for item_name, _ in sorted_items:
             sorted_entities[item_name] = entities[item_name]
@@ -297,7 +300,7 @@ class LlamaCppClient(LocalLLMClient):
         entity_ids = [
             state.entity_id for state in self.hass.states.async_all() \
-            if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id)
+            if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
         ]

         _LOGGER.debug(f"watching entities: {entity_ids}")
@@ -434,13 +437,21 @@
         _LOGGER.debug(f"Options: {entity_options}")

-        messages = get_oai_formatted_messages(conversation, user_content_as_list=True)
+        messages = get_oai_formatted_messages(conversation)
         tools = None
         if llm_api:
             tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())

         _LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...")

+        response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
+        response_format: Optional[ChatCompletionRequestResponseFormat] = None
+        if response_json_schema:
+            response_format = {
+                "type": "json_object",
+                "schema": response_json_schema,
+            }
+
         chat_completion = self.models[model_name].create_chat_completion(
             messages,
             tools=tools,
@@ -452,6 +463,7 @@
             max_tokens=max_tokens,
             grammar=grammar,
             stream=True,
+            response_format=response_format,
         )

         def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:
@@ -464,5 +476,5 @@
                 tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
                 yield content, tool_calls

-        return self._async_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())
+        return self._async_stream_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())
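For reference, a standalone sketch of the llama-cpp-python feature being wired up here: `create_chat_completion` accepts a `response_format` of type `json_object` with an embedded schema and constrains generation to match it. The model path and schema below are placeholders, not values from this repository.

```python
# Standalone llama-cpp-python sketch; model path and schema are placeholders.
from llama_cpp import Llama

llm = Llama(model_path="/path/to/model.gguf", n_ctx=4096)

schema = {
    "type": "object",
    "properties": {"summary": {"type": "string"}},
    "required": ["summary"],
}

completion = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Summarize: the porch light has been on for 3 hours."}],
    response_format={"type": "json_object", "schema": schema},
    max_tokens=128,
)
print(completion["choices"][0]["message"]["content"])  # JSON text conforming to the schema
```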


@@ -1,18 +1,19 @@
"""Defines the ollama compatible agent""" """Defines the Ollama compatible agent backed by the official python client."""
from __future__ import annotations from __future__ import annotations
from warnings import deprecated
import aiohttp
import asyncio
import json
import logging import logging
from typing import Optional, Tuple, Dict, List, Any, AsyncGenerator import ssl
from collections.abc import Mapping
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
import certifi
import httpx
from ollama import AsyncClient, ChatResponse, ResponseError
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.components import conversation as conversation from homeassistant.components import conversation as conversation
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm from homeassistant.helpers import llm
from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
@@ -23,119 +24,167 @@ from custom_components.llama_conversation.const import (
CONF_TOP_K, CONF_TOP_K,
CONF_TOP_P, CONF_TOP_P,
CONF_TYPICAL_P, CONF_TYPICAL_P,
CONF_MIN_P,
CONF_ENABLE_THINK_MODE,
CONF_REQUEST_TIMEOUT, CONF_REQUEST_TIMEOUT,
CONF_OPENAI_API_KEY, CONF_OPENAI_API_KEY,
CONF_GENERIC_OPENAI_PATH, CONF_GENERIC_OPENAI_PATH,
CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE, CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH, CONF_CONTEXT_LENGTH,
CONF_ENABLE_LEGACY_TOOL_CALLING,
CONF_RESPONSE_JSON_SCHEMA,
DEFAULT_MAX_TOKENS, DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE, DEFAULT_TEMPERATURE,
DEFAULT_TOP_K, DEFAULT_TOP_K,
DEFAULT_TOP_P, DEFAULT_TOP_P,
DEFAULT_TYPICAL_P, DEFAULT_TYPICAL_P,
DEFAULT_MIN_P,
DEFAULT_ENABLE_THINK_MODE,
DEFAULT_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT,
DEFAULT_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH,
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
) )
from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@deprecated("Use the built-in Ollama integration instead")
def _normalize_path(path: str | None) -> str:
if not path:
return ""
trimmed = str(path).strip("/")
return f"/{trimmed}" if trimmed else ""
def _build_default_ssl_context() -> ssl.SSLContext:
context = ssl.create_default_context()
try:
context.load_verify_locations(certifi.where())
except OSError as err:
_LOGGER.debug("Failed to load certifi bundle for Ollama client: %s", err)
return context
class OllamaAPIClient(LocalLLMClient): class OllamaAPIClient(LocalLLMClient):
api_host: str api_host: str
api_key: Optional[str] api_key: Optional[str]
def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None: def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
super().__init__(hass, client_options) super().__init__(hass, client_options)
base_path = _normalize_path(client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
self.api_host = format_url( self.api_host = format_url(
hostname=client_options[CONF_HOST], hostname=client_options[CONF_HOST],
port=client_options[CONF_PORT], port=client_options[CONF_PORT],
ssl=client_options[CONF_SSL], ssl=client_options[CONF_SSL],
path=client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH) path=base_path,
) )
self.api_key = client_options.get(CONF_OPENAI_API_KEY) or None
self._headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else None
self._ssl_context = _build_default_ssl_context() if client_options.get(CONF_SSL) else None
self.api_key = client_options.get(CONF_OPENAI_API_KEY, "") def _build_client(self, *, timeout: float | int | httpx.Timeout | None = None) -> AsyncClient:
timeout_config: httpx.Timeout | float | None = timeout
if isinstance(timeout, (int, float)):
timeout_config = httpx.Timeout(timeout)
return AsyncClient(
host=self.api_host,
headers=self._headers,
timeout=timeout_config,
verify=self._ssl_context,
)
@staticmethod @staticmethod
def get_name(client_options: dict[str, Any]): def get_name(client_options: dict[str, Any]):
host = client_options[CONF_HOST] host = client_options[CONF_HOST]
port = client_options[CONF_PORT] port = client_options[CONF_PORT]
ssl = client_options[CONF_SSL] ssl = client_options[CONF_SSL]
path = "/" + client_options[CONF_GENERIC_OPENAI_PATH] path = _normalize_path(client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
return f"Ollama at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'" return f"Ollama at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
@staticmethod @staticmethod
async def async_validate_connection(hass: HomeAssistant, user_input: Dict[str, Any]) -> str | None: async def async_validate_connection(hass: HomeAssistant, user_input: Dict[str, Any]) -> str | None:
headers = {}
api_key = user_input.get(CONF_OPENAI_API_KEY) api_key = user_input.get(CONF_OPENAI_API_KEY)
api_base_path = user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH) base_path = _normalize_path(user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
if api_key: timeout_config: httpx.Timeout | float | None = httpx.Timeout(5)
headers["Authorization"] = f"Bearer {api_key}"
verify_context = None
if user_input.get(CONF_SSL):
verify_context = await hass.async_add_executor_job(_build_default_ssl_context)
client = AsyncClient(
host=format_url(
hostname=user_input[CONF_HOST],
port=user_input[CONF_PORT],
ssl=user_input[CONF_SSL],
path=base_path,
),
headers={"Authorization": f"Bearer {api_key}"} if api_key else None,
timeout=timeout_config,
verify=verify_context,
)
try: try:
session = async_get_clientsession(hass) await client.list()
async with session.get( except httpx.TimeoutException:
format_url( return "Connection timed out"
hostname=user_input[CONF_HOST], except ResponseError as err:
port=user_input[CONF_PORT], return f"HTTP Status {err.status_code}: {err.error}"
ssl=user_input[CONF_SSL], except ConnectionError as err:
path=f"/{api_base_path}/api/tags" return str(err)
),
timeout=aiohttp.ClientTimeout(total=5), # quick timeout return None
headers=headers
) as response:
if response.ok:
return None
else:
return f"HTTP Status {response.status}"
except Exception as ex:
return str(ex)
async def async_get_available_models(self) -> List[str]: async def async_get_available_models(self) -> List[str]:
headers = {} client = self._build_client(timeout=5)
if self.api_key: try:
headers["Authorization"] = f"Bearer {self.api_key}" response = await client.list()
except httpx.TimeoutException as err:
raise HomeAssistantError("Timed out while fetching models from the Ollama server") from err
except (ResponseError, ConnectionError) as err:
raise HomeAssistantError(f"Failed to fetch models from the Ollama server: {err}") from err
session = async_get_clientsession(self.hass) models: List[str] = []
async with session.get( for model in getattr(response, "models", []) or []:
f"{self.api_host}/api/tags", candidate = getattr(model, "name", None) or getattr(model, "model", None)
timeout=aiohttp.ClientTimeout(total=5), # quick timeout if candidate:
headers=headers models.append(candidate)
) as response:
response.raise_for_status()
models_result = await response.json()
return [x["name"] for x in models_result["models"]] return models
def _extract_response(self, response_json: Dict) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]: def _extract_response(self, response_chunk: ChatResponse) -> Tuple[Optional[str], Optional[List[dict]]]:
# TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length content = response_chunk.message.content
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) raw_tool_calls = response_chunk.message.tool_calls
# max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
# if response_json["prompt_eval_count"] + max_tokens > context_len:
# self._warn_context_size()
if "response" in response_json: if raw_tool_calls:
response = response_json["response"] # return openai formatted tool calls
tool_calls = None tool_calls = [{
stop_reason = None "function": {
if response_json["done"] not in ["true", True]: "name": call.function.name,
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)") "arguments": call.function.arguments,
}
} for call in raw_tool_calls]
else: else:
response = response_json["message"]["content"] tool_calls = None
raw_tool_calls = response_json["message"].get("tool_calls")
tool_calls = [ llm.ToolInput(tool_name=x["function"]["name"], tool_args=x["function"]["arguments"]) for x in raw_tool_calls] if raw_tool_calls else None
stop_reason = response_json.get("done_reason")
# _LOGGER.debug(f"{response=} {tool_calls=}") return content, tool_calls
return response, tool_calls @staticmethod
def _format_keep_alive(value: Any) -> Any:
as_text = str(value).strip()
return 0 if as_text in {"0", "0.0"} else f"{as_text}m"
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, entity_options: Dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]: def _generate_stream(
self,
conversation: List[conversation.Content],
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: Dict[str, Any],
) -> AsyncGenerator[TextGenerationResult, None]:
model_name = entity_options.get(CONF_CHAT_MODEL, "") model_name = entity_options.get(CONF_CHAT_MODEL, "")
context_length = entity_options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) context_length = entity_options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
max_tokens = entity_options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) max_tokens = entity_options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -145,58 +194,48 @@ class OllamaAPIClient(LocalLLMClient):
typical_p = entity_options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P) typical_p = entity_options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
timeout = entity_options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT) timeout = entity_options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
keep_alive = entity_options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN) keep_alive = entity_options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
legacy_tool_calling = entity_options.get(CONF_ENABLE_LEGACY_TOOL_CALLING, DEFAULT_ENABLE_LEGACY_TOOL_CALLING)
think_mode = entity_options.get(CONF_ENABLE_THINK_MODE, DEFAULT_ENABLE_THINK_MODE)
json_mode = entity_options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE) json_mode = entity_options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
request_params = { options = {
"model": model_name, "num_ctx": context_length,
"stream": True, "top_p": top_p,
"keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model "top_k": top_k,
"options": { "typical_p": typical_p,
"num_ctx": context_length, "temperature": temperature,
"top_p": top_p, "num_predict": max_tokens,
"top_k": top_k, "min_p": entity_options.get(CONF_MIN_P, DEFAULT_MIN_P),
"typical_p": typical_p,
"temperature": temperature,
"num_predict": max_tokens,
},
} }
if json_mode: messages = get_oai_formatted_messages(conversation, tool_args_to_str=False)
request_params["format"] = "json" tools = None
if llm_api and not legacy_tool_calling:
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
keep_alive_payload = self._format_keep_alive(keep_alive)
if llm_api: async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains()) client = self._build_client(timeout=timeout)
endpoint = "/api/chat"
request_params["messages"] = get_oai_formatted_messages(conversation, tool_args_to_str=False)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
session = async_get_clientsession(self.hass)
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
response = None
chunk = None
try: try:
async with session.post( format_option = entity_options.get(CONF_RESPONSE_JSON_SCHEMA, "json" if json_mode else None)
f"{self.api_host}{endpoint}", stream = await client.chat(
json=request_params, model=model_name,
timeout=aiohttp.ClientTimeout(total=timeout), messages=messages,
headers=headers tools=tools,
) as response: stream=True,
response.raise_for_status() think=think_mode,
format=format_option,
while True: options=options,
chunk = await response.content.readline() keep_alive=keep_alive_payload,
if not chunk: )
break
async for chunk in stream:
yield self._extract_response(json.loads(chunk)) yield self._extract_response(chunk)
except asyncio.TimeoutError as err: except httpx.TimeoutException as err:
raise HomeAssistantError("The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.") from err raise HomeAssistantError(
except aiohttp.ClientError as err: "The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
) from err
except (ResponseError, ConnectionError) as err:
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token()) return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
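For context, a minimal sketch of the `ollama-python` streaming pattern the rewritten client builds on: `AsyncClient.chat(..., stream=True)` yields `ChatResponse` chunks whose `message` carries incremental content and any tool calls. The host, model, and options below are made up; the real client additionally wires in tools, think mode, the response format, and SSL/auth handling.

```python
# Minimal ollama-python streaming sketch; host, model, and options are illustrative only.
import asyncio

from ollama import AsyncClient


async def main() -> None:
    client = AsyncClient(host="http://localhost:11434", timeout=90)
    stream = await client.chat(
        model="qwen2.5:7b",
        messages=[{"role": "user", "content": "Turn on the kitchen lights"}],
        stream=True,
        options={"num_ctx": 8192, "temperature": 0.1},
        keep_alive="30m",  # same "<minutes>m" convention as _format_keep_alive above
    )
    async for chunk in stream:
        if chunk.message.content:
            print(chunk.message.content, end="", flush=True)
        if chunk.message.tool_calls:
            for call in chunk.message.tool_calls:
                print(f"\n[tool call] {call.function.name}({call.function.arguments})")


asyncio.run(main())
```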


@@ -10,6 +10,7 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime
from homeassistant.components import conversation, ai_task
from homeassistant.data_entry_flow import ( from homeassistant.data_entry_flow import (
AbortFlow, AbortFlow,
) )
@@ -46,6 +47,11 @@ from .const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
CONF_PROMPT, CONF_PROMPT,
DEFAULT_AI_TASK_PROMPT,
CONF_AI_TASK_RETRIES,
DEFAULT_AI_TASK_RETRIES,
CONF_AI_TASK_EXTRACTION_METHOD,
DEFAULT_AI_TASK_EXTRACTION_METHOD,
CONF_TEMPERATURE, CONF_TEMPERATURE,
CONF_TOP_K, CONF_TOP_K,
CONF_TOP_P, CONF_TOP_P,
@@ -150,7 +156,7 @@ from .const import (
DEFAULT_OPTIONS, DEFAULT_OPTIONS,
option_overrides, option_overrides,
RECOMMENDED_CHAT_MODELS, RECOMMENDED_CHAT_MODELS,
EMBEDDED_LLAMA_CPP_PYTHON_VERSION EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
) )
from . import HomeLLMAPI, LocalLLMConfigEntry, LocalLLMClient, BACKEND_TO_CLS from . import HomeLLMAPI, LocalLLMConfigEntry, LocalLLMClient, BACKEND_TO_CLS
@@ -225,7 +231,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
"""Handle a config flow for Local LLM Conversation.""" """Handle a config flow for Local LLM Conversation."""
VERSION = 3 VERSION = 3
MINOR_VERSION = 1 MINOR_VERSION = 2
install_wheel_task = None install_wheel_task = None
install_wheel_error = None install_wheel_error = None
@@ -399,8 +405,8 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
) -> dict[str, type[ConfigSubentryFlow]]: ) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration.""" """Return subentries supported by this integration."""
return { return {
"conversation": LocalLLMSubentryFlowHandler, conversation.DOMAIN: LocalLLMSubentryFlowHandler,
# "ai_task_data": LocalLLMSubentryFlowHandler, ai_task.DOMAIN: LocalLLMSubentryFlowHandler,
} }
@@ -583,40 +589,13 @@ def local_llama_config_option_schema(
backend_type: str, backend_type: str,
subentry_type: str, subentry_type: str,
) -> dict: ) -> dict:
default_prompt = build_prompt_template(language, DEFAULT_PROMPT)
result: dict = { result: dict = {
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, default_prompt)},
default=options.get(CONF_PROMPT, default_prompt),
): TemplateSelector(),
vol.Optional( vol.Optional(
CONF_TEMPERATURE, CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)}, description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)},
default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE), default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)), ): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_IN_CONTEXT_EXAMPLES_FILE,
description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
): str,
vol.Required(
CONF_NUM_IN_CONTEXT_EXAMPLES,
description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
vol.Required(
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
): TextSelector(TextSelectorConfig(multiple=True)),
vol.Required( vol.Required(
CONF_THINKING_PREFIX, CONF_THINKING_PREFIX,
description={"suggested_value": options.get(CONF_THINKING_PREFIX)}, description={"suggested_value": options.get(CONF_THINKING_PREFIX)},
@@ -644,9 +623,114 @@ def local_llama_config_option_schema(
): bool, ): bool,
} }
if backend_type == BACKEND_TYPE_LLAMA_CPP: if subentry_type == ai_task.DOMAIN:
result.update({ result.update({
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT)},
default=options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT),
): TemplateSelector(),
vol.Required(
CONF_AI_TASK_EXTRACTION_METHOD,
description={"suggested_value": options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)},
default=options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD),
): SelectSelector(SelectSelectorConfig(
options=[
SelectOptionDict(value="none", label="None"),
SelectOptionDict(value="structure", label="Structured output"),
SelectOptionDict(value="tool", label="Tool call"),
],
mode=SelectSelectorMode.DROPDOWN,
)),
vol.Required(
CONF_AI_TASK_RETRIES,
description={"suggested_value": options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES)},
default=options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES),
): NumberSelector(NumberSelectorConfig(min=0, max=5, step=1, mode=NumberSelectorMode.BOX)),
})
elif subentry_type == conversation.DOMAIN:
default_prompt = build_prompt_template(language, DEFAULT_PROMPT)
apis: list[SelectOptionDict] = [
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
]
result.update({
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, default_prompt)},
default=options.get(CONF_PROMPT, default_prompt),
): TemplateSelector(),
vol.Required(
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_IN_CONTEXT_EXAMPLES_FILE,
description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
): str,
vol.Required(
CONF_NUM_IN_CONTEXT_EXAMPLES,
description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
vol.Required(
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
): TextSelector(TextSelectorConfig(multiple=True)),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default=None,
): SelectSelector(SelectSelectorConfig(options=apis, multiple=True)),
vol.Optional(
CONF_REFRESH_SYSTEM_PROMPT,
description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)},
default=options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT),
): BooleanSelector(BooleanSelectorConfig()),
vol.Optional(
CONF_REMEMBER_CONVERSATION,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)},
default=options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION),
): BooleanSelector(BooleanSelectorConfig()),
vol.Optional(
CONF_REMEMBER_NUM_INTERACTIONS,
description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)},
default=options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS),
): NumberSelector(NumberSelectorConfig(min=0, max=100, mode=NumberSelectorMode.BOX)),
vol.Optional(
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION)},
default=options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION),
): NumberSelector(NumberSelectorConfig(min=0, max=1440, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_MAX_TOOL_CALL_ITERATIONS,
description={"suggested_value": options.get(CONF_MAX_TOOL_CALL_ITERATIONS)},
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
): int,
})
if backend_type == BACKEND_TYPE_LLAMA_CPP:
if subentry_type == conversation.DOMAIN:
result.update({
vol.Required( vol.Required(
CONF_PROMPT_CACHING_ENABLED,
description={"suggested_value": options.get(CONF_PROMPT_CACHING_ENABLED)},
default=DEFAULT_PROMPT_CACHING_ENABLED,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_PROMPT_CACHING_INTERVAL,
description={"suggested_value": options.get(CONF_PROMPT_CACHING_INTERVAL)},
default=DEFAULT_PROMPT_CACHING_INTERVAL,
): NumberSelector(NumberSelectorConfig(min=1, max=60, step=1)),
})
result.update({
vol.Required(
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)}, description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS, default=DEFAULT_MAX_TOKENS,
@@ -671,16 +755,6 @@ def local_llama_config_option_schema(
description={"suggested_value": options.get(CONF_TYPICAL_P)}, description={"suggested_value": options.get(CONF_TYPICAL_P)},
default=DEFAULT_TYPICAL_P, default=DEFAULT_TYPICAL_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_PROMPT_CACHING_ENABLED,
description={"suggested_value": options.get(CONF_PROMPT_CACHING_ENABLED)},
default=DEFAULT_PROMPT_CACHING_ENABLED,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_PROMPT_CACHING_INTERVAL,
description={"suggested_value": options.get(CONF_PROMPT_CACHING_INTERVAL)},
default=DEFAULT_PROMPT_CACHING_INTERVAL,
): NumberSelector(NumberSelectorConfig(min=1, max=60, step=1)),
# TODO: add rope_scaling_type # TODO: add rope_scaling_type
vol.Required( vol.Required(
CONF_CONTEXT_LENGTH, CONF_CONTEXT_LENGTH,
@@ -879,60 +953,13 @@ def local_llama_config_option_schema(
): NumberSelector(NumberSelectorConfig(min=-1, max=1440, step=1, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)), ): NumberSelector(NumberSelectorConfig(min=-1, max=1440, step=1, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)),
}) })
if subentry_type == "conversation":
apis: list[SelectOptionDict] = [
SelectOptionDict(
label="No control",
value="none",
)
]
apis.extend(
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
)
result.update({
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=apis)),
vol.Optional(
CONF_REFRESH_SYSTEM_PROMPT,
description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)},
default=options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT),
): BooleanSelector(BooleanSelectorConfig()),
vol.Optional(
CONF_REMEMBER_CONVERSATION,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)},
default=options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION),
): BooleanSelector(BooleanSelectorConfig()),
vol.Optional(
CONF_REMEMBER_NUM_INTERACTIONS,
description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)},
default=options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS),
): NumberSelector(NumberSelectorConfig(min=0, max=100, mode=NumberSelectorMode.BOX)),
vol.Optional(
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION)},
default=options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION),
): NumberSelector(NumberSelectorConfig(min=0, max=1440, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_MAX_TOOL_CALL_ITERATIONS,
description={"suggested_value": options.get(CONF_MAX_TOOL_CALL_ITERATIONS)},
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
): int,
})
elif subentry_type == "ai_task_data":
pass # no additional options for ai_task_data for now
# sort the options # sort the options
global_order = [ global_order = [
# general # general
CONF_LLM_HASS_API, CONF_LLM_HASS_API,
CONF_PROMPT, CONF_PROMPT,
CONF_AI_TASK_EXTRACTION_METHOD,
CONF_AI_TASK_RETRIES,
CONF_CONTEXT_LENGTH, CONF_CONTEXT_LENGTH,
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
# sampling parameters # sampling parameters
@@ -1122,8 +1149,16 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
description_placeholders = {} description_placeholders = {}
entry = self._get_entry() entry = self._get_entry()
backend_type = entry.data[CONF_BACKEND_TYPE] backend_type = entry.data[CONF_BACKEND_TYPE]
is_ai_task = self._subentry_type == ai_task.DOMAIN
if CONF_PROMPT not in self.model_config: if is_ai_task:
if CONF_PROMPT not in self.model_config:
self.model_config[CONF_PROMPT] = DEFAULT_AI_TASK_PROMPT
if CONF_AI_TASK_RETRIES not in self.model_config:
self.model_config[CONF_AI_TASK_RETRIES] = DEFAULT_AI_TASK_RETRIES
if CONF_AI_TASK_EXTRACTION_METHOD not in self.model_config:
self.model_config[CONF_AI_TASK_EXTRACTION_METHOD] = DEFAULT_AI_TASK_EXTRACTION_METHOD
elif CONF_PROMPT not in self.model_config:
# determine selected language from model config or parent options # determine selected language from model config or parent options
selected_language = self.model_config.get( selected_language = self.model_config.get(
CONF_SELECTED_LANGUAGE, entry.options.get(CONF_SELECTED_LANGUAGE, "en") CONF_SELECTED_LANGUAGE, entry.options.get(CONF_SELECTED_LANGUAGE, "en")
@@ -1156,20 +1191,21 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
) )
if user_input: if user_input:
if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED): if not is_ai_task:
errors["base"] = "sys_refresh_caching_enabled" if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
errors["base"] = "sys_refresh_caching_enabled"
if user_input.get(CONF_USE_GBNF_GRAMMAR): if user_input.get(CONF_USE_GBNF_GRAMMAR):
filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE) filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)): if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_gbnf_file" errors["base"] = "missing_gbnf_file"
description_placeholders["filename"] = filename description_placeholders["filename"] = filename
if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES): if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE) filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)): if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_icl_file" errors["base"] = "missing_icl_file"
description_placeholders["filename"] = filename description_placeholders["filename"] = filename
# --- Normalize numeric fields to ints to avoid slice/type errors later --- # --- Normalize numeric fields to ints to avoid slice/type errors later ---
for key in ( for key in (
@@ -1178,6 +1214,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
CONF_CONTEXT_LENGTH, CONF_CONTEXT_LENGTH,
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
CONF_REQUEST_TIMEOUT, CONF_REQUEST_TIMEOUT,
CONF_AI_TASK_RETRIES,
): ):
if key in user_input: if key in user_input:
user_input[key] = _coerce_int(user_input[key], user_input.get(key) or 0) user_input[key] = _coerce_int(user_input[key], user_input.get(key) or 0)
@@ -1187,10 +1224,6 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
# validate input # validate input
schema(user_input) schema(user_input)
self.model_config.update(user_input) self.model_config.update(user_input)
# clear LLM API if 'none' selected
if self.model_config.get(CONF_LLM_HASS_API) == "none":
self.model_config.pop(CONF_LLM_HASS_API, None)
return await self.async_step_finish() return await self.async_step_finish()
except Exception: except Exception:

View File

@@ -8,6 +8,11 @@ SERVICE_TOOL_NAME = "HassCallService"
SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"] SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"]
SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"] SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"]
CONF_PROMPT = "prompt" CONF_PROMPT = "prompt"
DEFAULT_AI_TASK_PROMPT = "You are a task-specific assistant. Follow the task instructions and return the requested data."
CONF_AI_TASK_RETRIES = "ai_task_retries"
DEFAULT_AI_TASK_RETRIES = 1
CONF_AI_TASK_EXTRACTION_METHOD = "ai_task_extraction_method"
DEFAULT_AI_TASK_EXTRACTION_METHOD = "structure"
PERSONA_PROMPTS = { PERSONA_PROMPTS = {
"en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.", "en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.",
"de": "Du bist \u201eAl\u201c, ein hilfreicher KI-Assistent, der die Ger\u00e4te in einem Haus steuert. F\u00fchren Sie die folgende Aufgabe gem\u00e4\u00df den Anweisungen durch oder beantworten Sie die folgende Frage nur mit den bereitgestellten Informationen.", "de": "Du bist \u201eAl\u201c, ein hilfreicher KI-Assistent, der die Ger\u00e4te in einem Haus steuert. F\u00fchren Sie die folgende Aufgabe gem\u00e4\u00df den Anweisungen durch oder beantworten Sie die folgende Frage nur mit den bereitgestellten Informationen.",
@@ -104,6 +109,8 @@ CONF_TEMPERATURE = "temperature"
DEFAULT_TEMPERATURE = 0.1 DEFAULT_TEMPERATURE = 0.1
CONF_REQUEST_TIMEOUT = "request_timeout" CONF_REQUEST_TIMEOUT = "request_timeout"
DEFAULT_REQUEST_TIMEOUT = 90 DEFAULT_REQUEST_TIMEOUT = 90
CONF_ENABLE_THINK_MODE = "enable_think_mode"
DEFAULT_ENABLE_THINK_MODE = False
CONF_BACKEND_TYPE = "model_backend" CONF_BACKEND_TYPE = "model_backend"
BACKEND_TYPE_LLAMA_HF_OLD = "llama_cpp_hf" BACKEND_TYPE_LLAMA_HF_OLD = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING_OLD = "llama_cpp_existing" BACKEND_TYPE_LLAMA_EXISTING_OLD = "llama_cpp_existing"
@@ -185,7 +192,8 @@ DEFAULT_GENERIC_OPENAI_PATH = "v1"
CONF_GENERIC_OPENAI_VALIDATE_MODEL = "openai_validate_model" CONF_GENERIC_OPENAI_VALIDATE_MODEL = "openai_validate_model"
DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL = True DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL = True
CONF_CONTEXT_LENGTH = "context_length" CONF_CONTEXT_LENGTH = "context_length"
DEFAULT_CONTEXT_LENGTH = 2048 DEFAULT_CONTEXT_LENGTH = 8192
CONF_RESPONSE_JSON_SCHEMA = "response_json_schema"
CONF_LLAMACPP_BATCH_SIZE = "batch_size" CONF_LLAMACPP_BATCH_SIZE = "batch_size"
DEFAULT_LLAMACPP_BATCH_SIZE = 512 DEFAULT_LLAMACPP_BATCH_SIZE = 512
CONF_LLAMACPP_THREAD_COUNT = "n_threads" CONF_LLAMACPP_THREAD_COUNT = "n_threads"

View File

@@ -17,6 +17,7 @@ from custom_components.llama_conversation.utils import MalformedToolCallExceptio
from .entity import LocalLLMEntity, LocalLLMClient, LocalLLMConfigEntry from .entity import LocalLLMEntity, LocalLLMClient, LocalLLMConfigEntry
from .const import ( from .const import (
CONF_CHAT_MODEL,
CONF_PROMPT, CONF_PROMPT,
CONF_REFRESH_SYSTEM_PROMPT, CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_CONVERSATION,
@@ -39,6 +40,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry, asy
if subentry.subentry_type != conversation.DOMAIN: if subentry.subentry_type != conversation.DOMAIN:
continue continue
if CONF_CHAT_MODEL not in subentry.data:
_LOGGER.warning("Conversation subentry %s missing required config key %s, You must delete the model and re-create it.", subentry.subentry_id, CONF_CHAT_MODEL)
continue
# create one agent entity per conversation subentry # create one agent entity per conversation subentry
agent_entity = LocalLLMAgent(hass, entry, subentry, entry.runtime_data) agent_entity = LocalLLMAgent(hass, entry, subentry, entry.runtime_data)

View File

@@ -5,15 +5,16 @@ import csv
import logging import logging
import os import os
import random import random
from typing import Literal, Any, List, Dict, Optional, Tuple, AsyncIterator, Generator, AsyncGenerator import re
from typing import Literal, Any, List, Dict, Optional, Sequence, Tuple, AsyncIterator, Generator, AsyncGenerator
from dataclasses import dataclass from dataclasses import dataclass
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import MATCH_ALL, CONF_LLM_HASS_API from homeassistant.const import MATCH_ALL, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import template, entity_registry as er, llm, \ from homeassistant.helpers import template, entity_registry as er, llm, \
area_registry as ar, device_registry as dr, entity area_registry as ar, device_registry as dr, entity
from homeassistant.util import color from homeassistant.util import color
@@ -41,8 +42,6 @@ from .const import (
DEFAULT_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX,
DEFAULT_TOOL_CALL_SUFFIX, DEFAULT_TOOL_CALL_SUFFIX,
DEFAULT_ENABLE_LEGACY_TOOL_CALLING, DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -184,28 +183,25 @@ class LocalLLMClient:
_LOGGER.debug("Received chunk: %s", input_chunk) _LOGGER.debug("Received chunk: %s", input_chunk)
tool_calls = input_chunk.tool_calls tool_calls = input_chunk.tool_calls
# fix tool calls for the service tool if tool_calls and not chat_log.llm_api:
if tool_calls and chat_log.llm_api and chat_log.llm_api.api.id == HOME_LLM_API_ID: raise HomeAssistantError("Model attempted to call a tool but no LLM API was provided")
tool_calls = [
llm.ToolInput(
tool_name=SERVICE_TOOL_NAME,
tool_args={**tc.tool_args, "service": tc.tool_name}
) for tc in tool_calls
]
yield conversation.AssistantContentDeltaDict( yield conversation.AssistantContentDeltaDict(
content=input_chunk.response, content=input_chunk.response,
tool_calls=tool_calls tool_calls=tool_calls
) )
return chat_log.async_add_delta_content_stream(agent_id, stream=async_iterator()) return chat_log.async_add_delta_content_stream(agent_id, stream=async_iterator())
async def _async_parse_completion( async def _async_stream_parse_completion(
self, llm_api: llm.APIInstance | None, self,
llm_api: llm.APIInstance | None,
agent_id: str, agent_id: str,
entity_options: Dict[str, Any], entity_options: Dict[str, Any],
next_token: Optional[Generator[Tuple[Optional[str], Optional[List]]]] = None, next_token: Optional[Generator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[List]]]] = None, anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
) -> AsyncGenerator[TextGenerationResult, None]: ) -> AsyncGenerator[TextGenerationResult, None]:
"""Parse streaming completion with tool calls from the backend. Accepts either a sync or async token generator."""
think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX) think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX) think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX) tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
@@ -236,7 +232,7 @@ class LocalLLMClient:
cur_match_length = 0 cur_match_length = 0
async for chunk in token_generator: async for chunk in token_generator:
# _LOGGER.debug(f"Handling chunk: {chunk} {in_thinking=} {in_tool_call=} {last_5_tokens=}") # _LOGGER.debug(f"Handling chunk: {chunk} {in_thinking=} {in_tool_call=} {last_5_tokens=}")
tool_calls: Optional[List[str | llm.ToolInput | dict]] tool_calls: Optional[List[str | dict]]
content, tool_calls = chunk content, tool_calls = chunk
if not tool_calls: if not tool_calls:
@@ -289,31 +285,73 @@ class LocalLLMClient:
_LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls") _LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
else: else:
for raw_tool_call in tool_calls: for raw_tool_call in tool_calls:
if isinstance(raw_tool_call, llm.ToolInput): if isinstance(raw_tool_call, str):
parsed_tool_calls.append(raw_tool_call) tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
else: else:
if isinstance(raw_tool_call, str): tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, agent_id)
else:
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, agent_id)
if tool_call: if tool_call:
_LOGGER.debug("Tool call parsed: %s", tool_call) _LOGGER.debug("Tool call parsed: %s", tool_call)
parsed_tool_calls.append(tool_call) parsed_tool_calls.append(tool_call)
if to_say: if to_say:
result.response = to_say result.response = to_say
if len(parsed_tool_calls) > 0: if len(parsed_tool_calls) > 0:
result.tool_calls = parsed_tool_calls result.tool_calls = parsed_tool_calls
if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls): if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
yield result yield result
async def _async_parse_completion(
self,
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: Dict[str, Any],
completion: str | dict) -> TextGenerationResult:
"""Parse completion with tool calls from the backend."""
think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
think_regex = re.compile(re.escape(think_prefix) + "(.*?)" + re.escape(think_suffix), re.DOTALL)
tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
tool_suffix = entity_options.get(CONF_TOOL_CALL_SUFFIX, DEFAULT_TOOL_CALL_SUFFIX)
tool_regex = re.compile(re.escape(tool_prefix) + "(.*?)" + re.escape(tool_suffix), re.DOTALL)
if isinstance(completion, dict):
completion = str(completion.get("response", ""))
# Remove thinking blocks, and extract tool calls
tool_calls = tool_regex.findall(completion)
completion = think_regex.sub("", completion)
completion = tool_regex.sub("", completion)
to_say = ""
parsed_tool_calls: list[llm.ToolInput] = []
if len(tool_calls) and not llm_api:
_LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
else:
for raw_tool_call in tool_calls:
if isinstance(raw_tool_call, llm.ToolInput):
parsed_tool_calls.append(raw_tool_call)
else:
if isinstance(raw_tool_call, str):
tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
else:
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
if tool_call:
_LOGGER.debug("Tool call parsed: %s", tool_call)
parsed_tool_calls.append(tool_call)
return TextGenerationResult(
response=completion + (to_say or ""),
tool_calls=parsed_tool_calls,
)
def _async_get_all_exposed_domains(self) -> list[str]: def _async_get_all_exposed_domains(self) -> list[str]:
"""Gather all exposed domains""" """Gather all exposed domains"""
domains = set() domains = set()
for state in self.hass.states.async_all(): for state in self.hass.states.async_all():
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id): if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id):
domains.add(state.domain) domains.add(state.domain)
return list(domains) return list(domains)
@@ -326,7 +364,7 @@ class LocalLLMClient:
area_registry = ar.async_get(self.hass) area_registry = ar.async_get(self.hass)
for state in self.hass.states.async_all(): for state in self.hass.states.async_all():
if not async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id): if not async_should_expose(self.hass, conversation.DOMAIN, state.entity_id):
continue continue
entity = entity_registry.async_get(state.entity_id) entity = entity_registry.async_get(state.entity_id)

View File

@@ -1,16 +1,17 @@
{ {
"domain": "llama_conversation", "domain": "llama_conversation",
"name": "Local LLMs", "name": "Local LLMs",
"version": "0.4.4", "version": "0.4.5",
"codeowners": ["@acon96"], "codeowners": ["@acon96"],
"config_flow": true, "config_flow": true,
"dependencies": ["conversation"], "dependencies": ["conversation", "ai_task"],
"after_dependencies": ["assist_pipeline", "intent"], "after_dependencies": ["assist_pipeline", "intent"],
"documentation": "https://github.com/acon96/home-llm", "documentation": "https://github.com/acon96/home-llm",
"integration_type": "service", "integration_type": "service",
"iot_class": "local_polling", "iot_class": "local_polling",
"requirements": [ "requirements": [
"huggingface-hub>=0.23.0", "huggingface-hub>=0.23.0",
"webcolors>=24.8.0" "webcolors>=24.8.0",
"ollama>=0.5.1"
] ]
} }

View File

@@ -70,7 +70,7 @@
"model_parameters": { "model_parameters": {
"data": { "data": {
"max_new_tokens": "Maximum tokens to return in response", "max_new_tokens": "Maximum tokens to return in response",
"llm_hass_api": "Selected LLM API", "llm_hass_api": "Selected LLM API(s)",
"prompt": "System Prompt", "prompt": "System Prompt",
"temperature": "Temperature", "temperature": "Temperature",
"top_k": "Top K", "top_k": "Top K",
@@ -109,7 +109,7 @@
"max_tool_call_iterations": "Maximum Tool Call Attempts" "max_tool_call_iterations": "Maximum Tool Call Attempts"
}, },
"data_description": { "data_description": {
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-3)'", "llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM (v1-3) model then select 'Home Assistant Services'",
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.", "prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this", "in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.", "extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
@@ -124,7 +124,7 @@
"reconfigure": { "reconfigure": {
"data": { "data": {
"max_new_tokens": "Maximum tokens to return in response", "max_new_tokens": "Maximum tokens to return in response",
"llm_hass_api": "Selected LLM API", "llm_hass_api": "Selected LLM API(s)",
"prompt": "System Prompt", "prompt": "System Prompt",
"temperature": "Temperature", "temperature": "Temperature",
"top_k": "Top K", "top_k": "Top K",
@@ -163,7 +163,7 @@
"max_tool_call_iterations": "Maximum Tool Call Attempts" "max_tool_call_iterations": "Maximum Tool Call Attempts"
}, },
"data_description": { "data_description": {
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-3)'", "llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM (v1-3) model then select 'Home Assistant Services'",
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.", "prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this", "in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.", "extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
@@ -177,7 +177,7 @@
} }
} }
}, },
"ai_task_data": { "ai_task": {
"initiate_flow": { "initiate_flow": {
"user": "Add AI Task Handler", "user": "Add AI Task Handler",
"reconfigure": "Reconfigure AI Task Handler" "reconfigure": "Reconfigure AI Task Handler"
@@ -246,7 +246,9 @@
"tool_call_prefix": "Tool Call Prefix", "tool_call_prefix": "Tool Call Prefix",
"tool_call_suffix": "Tool Call Suffix", "tool_call_suffix": "Tool Call Suffix",
"enable_legacy_tool_calling": "Enable Legacy Tool Calling", "enable_legacy_tool_calling": "Enable Legacy Tool Calling",
"max_tool_call_iterations": "Maximum Tool Call Attempts" "max_tool_call_iterations": "Maximum Tool Call Attempts",
"ai_task_extraction_method": "Structured Data Extraction Method",
"ai_task_retries": "Retry attempts for structured data extraction"
}, },
"data_description": { "data_description": {
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.", "prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
@@ -255,7 +257,8 @@
"gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.", "gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.",
"prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below.", "prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below.",
"enable_legacy_tool_calling": "Prefer to process tool calls locally rather than relying on the backend to handle the tool calling format. Can be more reliable, however it requires properly setting the tool call prefix and suffix.", "enable_legacy_tool_calling": "Prefer to process tool calls locally rather than relying on the backend to handle the tool calling format. Can be more reliable, however it requires properly setting the tool call prefix and suffix.",
"max_tool_call_iterations": "Set to 0 to generate the response and tool call in one attempt, without looping (use this for Home models v1-v3)." "max_tool_call_iterations": "Set to 0 to generate the response and tool call in one attempt, without looping (use this for Home models v1-v3).",
"ai_task_extraction_method": "Select the method used to extract structured data from the model's response. 'Structured Output' tells the backend to force the model to produce output following the provided JSON Schema; 'Tool Calling' provides a tool to the model that should be called with the appropriate arguments that match the desired output structure."
}, },
"description": "Please configure the model according to how it should be prompted. There are many different options and selecting the correct ones for your model is essential to getting optimal performance. See [here](https://github.com/acon96/home-llm/blob/develop/docs/Backend%20Configuration.md) for more information about the options on this page.\n\n**Some defaults may have been chosen for you based on the name of the selected model name or filename.** If you renamed a file or are using a fine-tuning of a supported model, then the defaults may not have been detected.", "description": "Please configure the model according to how it should be prompted. There are many different options and selecting the correct ones for your model is essential to getting optimal performance. See [here](https://github.com/acon96/home-llm/blob/develop/docs/Backend%20Configuration.md) for more information about the options on this page.\n\n**Some defaults may have been chosen for you based on the name of the selected model name or filename.** If you renamed a file or are using a fine-tuning of a supported model, then the defaults may not have been detected.",
"title": "Configure the selected model" "title": "Configure the selected model"
@@ -263,7 +266,7 @@
"reconfigure": { "reconfigure": {
"data": { "data": {
"max_new_tokens": "Maximum tokens to return in response", "max_new_tokens": "Maximum tokens to return in response",
"llm_hass_api": "Selected LLM API", "llm_hass_api": "Selected LLM API(s)",
"prompt": "System Prompt", "prompt": "System Prompt",
"temperature": "Temperature", "temperature": "Temperature",
"top_k": "Top K", "top_k": "Top K",

View File

@@ -24,7 +24,7 @@ from homeassistant.requirements import pip_kwargs
from homeassistant.util import color from homeassistant.util import color
from homeassistant.util.package import install_package, is_installed from homeassistant.util.package import install_package, is_installed
from voluptuous_openapi import convert from voluptuous_openapi import convert as convert_to_openapi
from .const import ( from .const import (
DOMAIN, DOMAIN,
@@ -32,7 +32,7 @@ from .const import (
ALLOWED_SERVICE_CALL_ARGUMENTS, ALLOWED_SERVICE_CALL_ARGUMENTS,
SERVICE_TOOL_ALLOWED_SERVICES, SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS, SERVICE_TOOL_ALLOWED_DOMAINS,
HOME_LLM_API_ID, SERVICE_TOOL_NAME,
) )
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -275,26 +275,30 @@ def install_llama_cpp_python(config_dir: str, force_reinstall: bool = False, spe
def format_url(*, hostname: str, port: str, ssl: bool, path: str): def format_url(*, hostname: str, port: str, ssl: bool, path: str):
return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}" return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]: def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
if llm_api.api.id == HOME_LLM_API_ID: result: List[ChatCompletionTool] = []
result: List[ChatCompletionTool] = [ {
"type": "function", for tool in llm_api.tools:
"function": { # when combining with home assistant llm APIs, it adds a prefix to differentiate tools; compare against the suffix here
"name": tool["name"], if tool.name.endswith(SERVICE_TOOL_NAME):
"description": f"Call the Home Assistant service '{tool['name']}'", result.extend([{
"parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer) "type": "function",
} "function": {
} for tool in get_home_llm_tools(llm_api, domains) ] "name": tool["name"],
"description": f"Call the Home Assistant service '{tool['name']}'",
else: "parameters": convert_to_openapi(tool["arguments"], custom_serializer=llm_api.custom_serializer)
result: List[ChatCompletionTool] = [ { }
"type": "function", } for tool in get_home_llm_tools(llm_api, domains) ])
"function": { else:
"name": tool.name, result.append({
"description": tool.description or "", "type": "function",
"parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer) "function": {
} "name": tool.name,
} for tool in llm_api.tools ] "description": tool.description or "",
"parameters": convert_to_openapi(tool.parameters, custom_serializer=llm_api.custom_serializer)
}
})
return result return result
@@ -396,41 +400,44 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic
return tools return tools
def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, agent_id: str) -> tuple[llm.ToolInput | None, str | None]: def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
if isinstance(raw_block, dict): if isinstance(raw_block, dict):
parsed_tool_call = raw_block parsed_tool_call = raw_block
else: else:
parsed_tool_call: dict = json.loads(raw_block) parsed_tool_call: dict = json.loads(raw_block)
if llm_api.api.id == HOME_LLM_API_ID: # try to validate either format
schema_to_validate = vol.Schema({ is_services_tool_call = False
vol.Required('service'): str, try:
vol.Required('target_device'): str, base_schema_to_validate = vol.Schema({
vol.Optional('rgb_color'): str,
vol.Optional('brightness'): vol.Coerce(float),
vol.Optional('temperature'): vol.Coerce(float),
vol.Optional('humidity'): vol.Coerce(float),
vol.Optional('fan_mode'): str,
vol.Optional('hvac_mode'): str,
vol.Optional('preset_mode'): str,
vol.Optional('duration'): str,
vol.Optional('item'): str,
})
else:
schema_to_validate = vol.Schema({
vol.Required("name"): str, vol.Required("name"): str,
vol.Required("arguments"): vol.Union(str, dict), vol.Required("arguments"): vol.Union(str, dict),
}) })
base_schema_to_validate(parsed_tool_call)
try:
schema_to_validate(parsed_tool_call)
except vol.Error as ex: except vol.Error as ex:
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}") try:
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted") home_llm_schema_to_validate = vol.Schema({
vol.Required('service'): str,
vol.Required('target_device'): str,
vol.Optional('rgb_color'): str,
vol.Optional('brightness'): vol.Coerce(float),
vol.Optional('temperature'): vol.Coerce(float),
vol.Optional('humidity'): vol.Coerce(float),
vol.Optional('fan_mode'): str,
vol.Optional('hvac_mode'): str,
vol.Optional('preset_mode'): str,
vol.Optional('duration'): str,
vol.Optional('item'): str,
})
home_llm_schema_to_validate(parsed_tool_call)
is_services_tool_call = True
except vol.Error as ex:
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
# try to fix certain arguments # try to fix certain arguments
args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"] args_dict = parsed_tool_call if is_services_tool_call else parsed_tool_call["arguments"]
tool_name = parsed_tool_call.get("name", parsed_tool_call.get("service", "")) tool_name = SERVICE_TOOL_NAME if is_services_tool_call else parsed_tool_call["name"]
if isinstance(args_dict, str): if isinstance(args_dict, str):
if not args_dict.strip(): if not args_dict.strip():

View File

@@ -6,4 +6,5 @@ home-assistant-intents
# testing requirements # testing requirements
pytest pytest
pytest-asyncio pytest-asyncio
pytest-homeassistant-custom-component==0.13.260 # NOTE this must match the version of Home Assistant used for testing
pytest-homeassistant-custom-component==0.13.272

View File

@@ -1,2 +1,3 @@
huggingface-hub>=0.23.0 huggingface-hub>=0.23.0
webcolors>=24.8.0 webcolors>=24.8.0
ollama>=0.5.1

View File

@@ -28,8 +28,8 @@ Start by installing system dependencies:
Then create a Python virtual environment and install all necessary library: Then create a Python virtual environment and install all necessary library:
``` ```
python3 -m venv .generate_data python3 -m venv .generate_data
source ./.generate_data/bin/activate source .generate_data/bin/activate
pip3 install pandas==2.2.2 datasets==2.20.0 webcolors==1.13 babel==2.15.0 pip3 install -r requirements.txt
``` ```
## Generating the dataset from piles ## Generating the dataset from piles

View File

@@ -1,5 +1,5 @@
datasets>=3.2.0 datasets>=3.2.0
webcolors>=1.13 webcolors>=24.8.0
pandas>=2.2.3 pandas>=2.2.3
deep-translator>=1.11.4 deep-translator>=1.11.4
langcodes>=3.5.0 langcodes>=3.5.0

52
docs/AI Tasks.md Normal file
View File

@@ -0,0 +1,52 @@
# Using AI Tasks
The AI Tasks feature allows you to define structured tasks that your local LLM can perform. These tasks can be integrated into Home Assistant automations and scripts, enabling you to generate dynamic content based on specific prompts and instructions.
## Setting up an AI Task Handler
Setting up a task handler is similar to setting up a conversation agent. You can choose to run the model directly within Home Assistant using `llama-cpp-python`, or you can use an external backend like Ollama. See the [Setup Guide](./Setup.md) for detailed instructions on configuring your AI Task handler.
The specific configuration options for AI Tasks are:
| Option Name | Description |
|-----------------------------------|-------------------------------------------------------------------------------------------------------------------------------------|
| Structured Data Extraction Method | Choose how the AI Task should extract structured data from the model's output. Options include `structured_output` and `tool`. |
| Data Extraction Retry Count | The number of times to retry data extraction if the initial attempt fails. Useful when models can produce incorrect tool responses. |
If no structured data extraction method is specified, then the task entity will always return raw text.
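For tasks that only need free-form text, you can call the `ai_task.generate_data` action (described in the next section) without a `structure` block. The following is a minimal sketch under that assumption; `ai_task.my_local_llm` and `weather.home` are placeholder IDs that you would replace with your own entities.

```yaml
sequence:
  - action: ai_task.generate_data
    data:
      task_name: Weather Summary
      instructions: >-
        Write a one-sentence, plain-language summary of the current weather:
        {{ states('weather.home') }}, {{ state_attr('weather.home', 'temperature') }} degrees.
      # Placeholder entity ID; replace with your own AI Task entity
      entity_id: ai_task.my_local_llm
    response_variable: weather_output
  - action: notify.persistent_notification
    data:
      # Without a `structure` block, the raw text is returned in `data`
      message: "{{ weather_output.data }}"
```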
## Using an AI Task in a Script or Automation
To use an AI Task in a Home Assistant script or automation, use the `ai_task.generate_data` action. This action lets you specify the task name, the instructions, and the structure of the expected output. Below is an example of a script that generates a joke about one of the smart devices in your home.
**Device Joke Script:**
```yaml
sequence:
- action: ai_task.generate_data
data:
task_name: Device Joke Generation
instructions: |
Write a funny joke about one of the smart devices in my home.
Here are all of the smart devices I have:
{% for device in states | rejectattr('domain', 'in', ['update', 'event']) -%}
- {{ device.name }} ({{device.domain}})
{% endfor %}
# You MUST set this to your own LLM entity ID if you do not set a default one in HA Settings
# entity_id: ai_task.unsloth_qwen3_0_6b_gguf_unsloth_qwen3_0_6b_gguf
structure:
joke_setup:
description: The beginning of a joke about a smart device in the home
required: true
selector:
text: null
joke_punchline:
description: The punchline of the same joke about the smart device
required: true
selector:
text: null
response_variable: joke_output
- action: notify.persistent_notification
data:
message: |-
{{ joke_output.data.joke_setup }}
...
{{ joke_output.data.joke_punchline }}
alias: Device Joke
description: "Generates a funny joke about one of the smart devices in the home."
```
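
Each key defined under `structure` becomes a field on the returned `data` object, which is why the notification step above can reference `joke_output.data.joke_setup` and `joke_output.data.joke_punchline` directly.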

View File

@@ -4,7 +4,7 @@ datasets>=3.2.0
peft>=0.14.0 peft>=0.14.0
bitsandbytes>=0.45.2 bitsandbytes>=0.45.2
trl>=0.14.0 trl>=0.14.0
webcolors>=1.13 webcolors>=24.8.0
pandas>=2.2.3 pandas>=2.2.3
# flash-attn # flash-attn
sentencepiece>=0.2.0 sentencepiece>=0.2.0

View File

@@ -1,763 +0,0 @@
import json
import logging
import pytest
import jinja2
from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY
from custom_components.llama_conversation.backends.llamacpp import LlamaCppAgent
from custom_components.llama_conversation.backends.ollama import OllamaAPIAgent
from custom_components.llama_conversation.backends.tailored_openai import TextGenerationWebuiAgent
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_PROMPT_TEMPLATE,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
CONF_BATCH_SIZE,
CONF_THREAD_COUNT,
CONF_BATCH_THREAD_COUNT,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT_BASE,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_PROMPT_CACHING_ENABLED,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_BATCH_SIZE,
DEFAULT_THREAD_COUNT,
DEFAULT_BATCH_THREAD_COUNT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
DOMAIN,
PROMPT_TEMPLATE_DESCRIPTIONS,
DEFAULT_OPTIONS,
)
from homeassistant.components.conversation import ConversationInput
from homeassistant.const import (
CONF_HOST,
CONF_PORT,
CONF_SSL,
CONF_LLM_HASS_API
)
from homeassistant.helpers.llm import LLM_API_ASSIST, APIInstance
_LOGGER = logging.getLogger(__name__)
class WarnDict(dict):
def get(self, _key, _default=None):
if _key in self:
return self[_key]
_LOGGER.warning(f"attempting to get unset dictionary key {_key}")
return _default
class MockConfigEntry:
def __init__(self, entry_id='test_entry_id', data={}, options={}):
self.entry_id = entry_id
self.data = WarnDict(data)
# Use a mutable dict for options in tests
self.options = WarnDict(dict(options))
@pytest.fixture
def config_entry():
yield MockConfigEntry(
data={
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
CONF_BACKEND_TYPE: DEFAULT_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE: "/config/models/some-model.q4_k_m.gguf",
CONF_HOST: "localhost",
CONF_PORT: "5000",
CONF_SSL: False,
CONF_OPENAI_API_KEY: "OpenAI-API-Key",
CONF_TEXT_GEN_WEBUI_ADMIN_KEY: "Text-Gen-Webui-Admin-Key"
},
options={
**DEFAULT_OPTIONS,
CONF_LLM_HASS_API: LLM_API_ASSIST,
CONF_PROMPT: DEFAULT_PROMPT_BASE,
CONF_SERVICE_CALL_REGEX: r"({[\S \t]*})"
}
)
@pytest.fixture
def local_llama_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(LlamaCppAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(LlamaCppAgent, '_load_grammar') as load_grammar_mock, \
patch.object(LlamaCppAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.llamacpp.importlib.import_module') as import_module_mock, \
patch('custom_components.llama_conversation.utils.importlib.import_module') as import_module_mock_2, \
patch('custom_components.llama_conversation.utils.install_llama_cpp_python') as install_llama_cpp_python_mock:
entry_mock.return_value = config_entry
llama_instance_mock = MagicMock()
llama_class_mock = MagicMock()
llama_class_mock.return_value = llama_instance_mock
import_module_mock.return_value = MagicMock(Llama=llama_class_mock)
import_module_mock_2.return_value = MagicMock(Llama=llama_class_mock)
install_llama_cpp_python_mock.return_value = True
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
# template_mock.side_affect = lambda template, _: jinja2.Template(template)
generate_mock = llama_instance_mock.generate
generate_mock.return_value = list(range(20))
detokenize_mock = llama_instance_mock.detokenize
detokenize_mock.return_value = ("I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
})).encode()
call_tool_mock.return_value = {"result": "success"}
agent_obj = LlamaCppAgent(
hass,
config_entry
)
all_mocks = {
"llama_class": llama_class_mock,
"tokenize": llama_instance_mock.tokenize,
"generate": generate_mock,
}
yield agent_obj, all_mocks
# TODO: test base llama agent (ICL loading other languages)
async def test_local_llama_agent(local_llama_agent_fixture):
local_llama_agent: LlamaCppAgent
all_mocks: dict[str, MagicMock]
local_llama_agent, all_mocks = local_llama_agent_fixture
# invoke the conversation agent
conversation_id = "test-conversation"
result = await local_llama_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["llama_class"].assert_called_once_with(
model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
)
all_mocks["tokenize"].assert_called_once()
all_mocks["generate"].assert_called_once_with(
ANY,
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
min_p=local_llama_agent.entry.options[CONF_MIN_P],
grammar=ANY,
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# change options then apply them
local_llama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
local_llama_agent.entry.options[CONF_BATCH_SIZE] = 1024
local_llama_agent.entry.options[CONF_THREAD_COUNT] = 24
local_llama_agent.entry.options[CONF_BATCH_THREAD_COUNT] = 24
local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
local_llama_agent.entry.options[CONF_ENABLE_FLASH_ATTENTION] = True
local_llama_agent.entry.options[CONF_TOP_K] = 20
local_llama_agent.entry.options[CONF_TOP_P] = 0.9
local_llama_agent.entry.options[CONF_MIN_P] = 0.2
local_llama_agent.entry.options[CONF_TYPICAL_P] = 0.95
local_llama_agent._update_options()
all_mocks["llama_class"].assert_called_once_with(
model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
)
# do another turn of the same conversation
result = await local_llama_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["tokenize"].assert_called_once()
all_mocks["generate"].assert_called_once_with(
ANY,
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
min_p=local_llama_agent.entry.options[CONF_MIN_P],
grammar=ANY,
)
@pytest.fixture
def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(OllamaAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(OllamaAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(OllamaAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.ollama.async_get_clientsession') as get_clientsession:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
response_mock = MagicMock()
response_mock.json.return_value = { "models": [ {"name": config_entry.data[CONF_CHAT_MODEL] }] }
get_clientsession.get.return_value = response_mock
call_tool_mock.return_value = {"result": "success"}
agent_obj = OllamaAPIAgent(
hass,
config_entry
)
all_mocks = {
"requests_get": get_clientsession.get,
"requests_post": get_clientsession.post
}
yield agent_obj, all_mocks
async def test_ollama_agent(ollama_agent_fixture):
ollama_agent: OllamaAPIAgent
all_mocks: dict[str, MagicMock]
ollama_agent, all_mocks = ollama_agent_fixture
all_mocks["requests_get"].assert_called_once_with(
"http://localhost:5000/api/tags",
headers={ "Authorization": "Bearer OpenAI-API-Key" }
)
response_mock = MagicMock()
response_mock.json.return_value = {
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
"created_at": "2023-11-09T21:07:55.186497Z",
"response": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"done": True,
"context": [1, 2, 3],
"total_duration": 4648158584,
"load_duration": 4071084,
"prompt_eval_count": 36,
"prompt_eval_duration": 439038000,
"eval_count": 180,
"eval_duration": 4196918000
}
all_mocks["requests_post"].return_value = response_mock
# invoke the conversation agent
conversation_id = "test-conversation"
result = await ollama_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/api/generate",
headers={ "Authorization": "Bearer OpenAI-API-Key" },
json={
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
"stream": False,
"keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
"options": {
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
"top_p": ollama_agent.entry.options[CONF_TOP_P],
"top_k": ollama_agent.entry.options[CONF_TOP_K],
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
},
"prompt": ANY,
"raw": True
},
timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# change options
ollama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
ollama_agent.entry.options[CONF_MAX_TOKENS] = 10
ollama_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN] = 99
ollama_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
ollama_agent.entry.options[CONF_OLLAMA_JSON_MODE] = True
ollama_agent.entry.options[CONF_TEMPERATURE] = 2.0
ollama_agent.entry.options[CONF_TOP_K] = 20
ollama_agent.entry.options[CONF_TOP_P] = 0.9
ollama_agent.entry.options[CONF_TYPICAL_P] = 0.5
# do another turn of the same conversation
result = await ollama_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/api/chat",
headers={ "Authorization": "Bearer OpenAI-API-Key" },
json={
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
"stream": False,
"format": "json",
"keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
"options": {
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
"top_p": ollama_agent.entry.options[CONF_TOP_P],
"top_k": ollama_agent.entry.options[CONF_TOP_K],
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
},
"messages": ANY
},
timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
@pytest.fixture
def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(TextGenerationWebuiAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(TextGenerationWebuiAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(TextGenerationWebuiAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.tailored_openai.async_get_clientsession') as get_clientsession_mock:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
response_mock = MagicMock()
response_mock.json.return_value = { "model_name": config_entry.data[CONF_CHAT_MODEL] }
get_clientsession_mock.get.return_value = response_mock
call_tool_mock.return_value = {"result": "success"}
agent_obj = TextGenerationWebuiAgent(
hass,
config_entry
)
all_mocks = {
"requests_get": get_clientsession_mock.get,
"requests_post": get_clientsession_mock.post
}
yield agent_obj, all_mocks
async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
text_generation_webui_agent: TextGenerationWebuiAgent
all_mocks: dict[str, MagicMock]
text_generation_webui_agent, all_mocks = text_generation_webui_agent_fixture
all_mocks["requests_get"].assert_called_once_with(
"http://localhost:5000/v1/internal/model/info",
headers={ "Authorization": "Bearer Text-Gen-Webui-Admin-Key" }
)
response_mock = MagicMock()
response_mock.json.return_value = {
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "gpt-3.5-turbo-instruct",
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"text": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"index": 0,
"logprobs": None,
"finish_reason": "length"
}],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
}
}
all_mocks["requests_post"].return_value = response_mock
# invoke the conversation agent
conversation_id = "test-conversation"
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"prompt": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = "Some Preset"
# do another turn of the same conversation and use a preset
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"preset": "Some Preset",
"prompt": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# change options
text_generation_webui_agent.entry.options[CONF_MAX_TOKENS] = 10
text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
text_generation_webui_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
text_generation_webui_agent.entry.options[CONF_TEMPERATURE] = 2.0
text_generation_webui_agent.entry.options[CONF_TOP_P] = 0.9
text_generation_webui_agent.entry.options[CONF_MIN_P] = 0.2
text_generation_webui_agent.entry.options[CONF_TYPICAL_P] = 0.95
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = ""
response_mock.json.return_value = {
"id": "chatcmpl-123",
# text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
"object": "chat.completions",
"created": 1677652288,
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# do another turn of the same conversation, this time using the chat endpoint
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/chat/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
"messages": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
@pytest.fixture
def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(GenericOpenAIAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(GenericOpenAIAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(GenericOpenAIAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.generic_openai.async_get_clientsession') as get_clientsession_mock:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
call_tool_mock.return_value = {"result": "success"}
agent_obj = GenericOpenAIAPIAgent(
hass,
config_entry
)
all_mocks = {
"requests_get": get_clientsession_mock.get,
"requests_post": get_clientsession_mock.post
}
yield agent_obj, all_mocks
async def test_generic_openai_agent(generic_openai_agent_fixture):
generic_openai_agent: GenericOpenAIAPIAgent
all_mocks: dict[str, MagicMock]
generic_openai_agent, all_mocks = generic_openai_agent_fixture
response_mock = MagicMock()
response_mock.json.return_value = {
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "gpt-3.5-turbo-instruct",
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"text": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"index": 0,
"logprobs": None,
"finish_reason": "length"
}],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
}
}
all_mocks["requests_post"].return_value = response_mock
# invoke the conversation agent
conversation_id = "test-conversation"
result = await generic_openai_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/completions",
json={
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
"top_p": generic_openai_agent.entry.options[CONF_TOP_P],
"temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
"max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
"prompt": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# change options
generic_openai_agent.entry.options[CONF_MAX_TOKENS] = 10
generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
generic_openai_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
generic_openai_agent.entry.options[CONF_TEMPERATURE] = 2.0
generic_openai_agent.entry.options[CONF_TOP_P] = 0.9
response_mock.json.return_value = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}
# do another turn of the same conversation, this time using the chat endpoint
result = await generic_openai_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/chat/completions",
json={
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
"top_p": generic_openai_agent.entry.options[CONF_TOP_P],
"temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
"max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
"messages": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
)

View File

@@ -0,0 +1,144 @@
"""Tests for AI Task extraction behavior."""
from typing import Any, cast
import pytest
import voluptuous as vol
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from custom_components.llama_conversation.ai_task import (
LocalLLMTaskEntity,
ResultExtractionMethod,
)
from custom_components.llama_conversation.const import (
CONF_AI_TASK_EXTRACTION_METHOD,
)
from custom_components.llama_conversation.entity import TextGenerationResult
class DummyGenTask:
def __init__(self, *, name="task", instructions="do", attachments=None, structure=None):
self.name = name
self.instructions = instructions
self.attachments = attachments or []
self.structure = structure
class DummyChatLog:
def __init__(self, content=None):
self.content = content or []
self.conversation_id = "conv-id"
self.llm_api = None
class DummyClient:
def __init__(self, result: TextGenerationResult):
self._result = result
def _generate_system_prompt(self, prompt_template, llm_api, entity_options):
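        # Return the template unrendered; prompt rendering itself is covered by the entity tests.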
return prompt_template
def _supports_vision(self, _options): # pragma: no cover - not needed for tests
return False
async def _generate(self, _messages, _llm_api, _entity_id, _options):
return self._result
class DummyTaskEntity(LocalLLMTaskEntity):
def __init__(self, hass, client, options):
# Bypass parent init to avoid ConfigEntry/Subentry plumbing.
self.hass = hass
self.client = client
self._runtime_options = options
self.entity_id = "ai_task.test"
self.entry_id = "entry"
self.subentry_id = "subentry"
self._attr_supported_features = 0
@property
def runtime_options(self):
return self._runtime_options
@pytest.mark.asyncio
async def test_no_extraction_returns_raw_text(hass):
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="raw text")),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.NONE},
)
chat_log = DummyChatLog()
task = DummyGenTask()
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
assert result.data == "raw text"
assert chat_log.llm_api is None
@pytest.mark.asyncio
async def test_structured_output_success(hass):
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response='{"foo": 1}')),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure=vol.Schema({"foo": int}))
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
assert result.data == {"foo": 1}
@pytest.mark.asyncio
async def test_structured_output_invalid_json_raises(hass):
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="not-json")),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure=vol.Schema({"foo": int}))
with pytest.raises(HomeAssistantError):
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
@pytest.mark.asyncio
async def test_tool_extraction_success(hass):
tool_call = llm.ToolInput("submit_response", {"value": 42})
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="", tool_calls=[tool_call])),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure=vol.Schema({"value": int}))
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
assert result.data == {"value": 42}
assert chat_log.llm_api is not None
@pytest.mark.asyncio
async def test_tool_extraction_missing_tool_args_raises(hass):
class DummyToolCall:
def __init__(self, tool_args=None):
self.tool_args = tool_args
tool_call = DummyToolCall(tool_args=None)
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="", tool_calls=cast(Any, [tool_call]))),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure=vol.Schema({"value": int}))
with pytest.raises(HomeAssistantError):
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))

View File

@@ -0,0 +1,105 @@
"""Lightweight smoke tests for backend helpers.
These avoid backend calls and only cover helper utilities to keep the suite green
while the integration evolves. No integration code is modified.
"""
import pytest
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from custom_components.llama_conversation.backends.llamacpp import snapshot_settings
from custom_components.llama_conversation.backends.ollama import OllamaAPIClient, _normalize_path
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIClient
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_CONTEXT_LENGTH,
CONF_LLAMACPP_BATCH_SIZE,
CONF_LLAMACPP_BATCH_THREAD_COUNT,
CONF_LLAMACPP_THREAD_COUNT,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
CONF_GBNF_GRAMMAR_FILE,
CONF_PROMPT_CACHING_ENABLED,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_LLAMACPP_BATCH_SIZE,
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
DEFAULT_LLAMACPP_THREAD_COUNT,
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_PROMPT_CACHING_ENABLED,
CONF_GENERIC_OPENAI_PATH,
)
@pytest.fixture
async def hass_defaults(hass):
return hass
def test_snapshot_settings_defaults():
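    # With only the model configured, every llama.cpp tunable should fall back to its DEFAULT_* constant.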
options = {CONF_CHAT_MODEL: "test-model"}
snap = snapshot_settings(options)
assert snap[CONF_CONTEXT_LENGTH] == DEFAULT_CONTEXT_LENGTH
assert snap[CONF_LLAMACPP_BATCH_SIZE] == DEFAULT_LLAMACPP_BATCH_SIZE
assert snap[CONF_LLAMACPP_THREAD_COUNT] == DEFAULT_LLAMACPP_THREAD_COUNT
assert snap[CONF_LLAMACPP_BATCH_THREAD_COUNT] == DEFAULT_LLAMACPP_BATCH_THREAD_COUNT
assert snap[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] == DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION
assert snap[CONF_GBNF_GRAMMAR_FILE] == DEFAULT_GBNF_GRAMMAR_FILE
assert snap[CONF_PROMPT_CACHING_ENABLED] == DEFAULT_PROMPT_CACHING_ENABLED
def test_snapshot_settings_overrides():
options = {
CONF_CONTEXT_LENGTH: 4096,
CONF_LLAMACPP_BATCH_SIZE: 64,
CONF_LLAMACPP_THREAD_COUNT: 6,
CONF_LLAMACPP_BATCH_THREAD_COUNT: 3,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION: True,
CONF_GBNF_GRAMMAR_FILE: "custom.gbnf",
CONF_PROMPT_CACHING_ENABLED: True,
}
snap = snapshot_settings(options)
assert snap[CONF_CONTEXT_LENGTH] == 4096
assert snap[CONF_LLAMACPP_BATCH_SIZE] == 64
assert snap[CONF_LLAMACPP_THREAD_COUNT] == 6
assert snap[CONF_LLAMACPP_BATCH_THREAD_COUNT] == 3
assert snap[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] is True
assert snap[CONF_GBNF_GRAMMAR_FILE] == "custom.gbnf"
assert snap[CONF_PROMPT_CACHING_ENABLED] is True
def test_ollama_keep_alive_formatting():
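    # Zero-valued keep-alive inputs map to 0; anything else is treated as minutes and formatted as "<n>m".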
assert OllamaAPIClient._format_keep_alive("0") == 0
assert OllamaAPIClient._format_keep_alive("0.0") == 0
assert OllamaAPIClient._format_keep_alive(5) == "5m"
assert OllamaAPIClient._format_keep_alive("15") == "15m"
def test_generic_openai_name_and_path(hass_defaults):
client = GenericOpenAIAPIClient(
hass_defaults,
{
CONF_HOST: "localhost",
CONF_PORT: "8080",
CONF_SSL: False,
CONF_GENERIC_OPENAI_PATH: "v1",
CONF_CHAT_MODEL: "demo",
},
)
name = client.get_name(
{
CONF_HOST: "localhost",
CONF_PORT: "8080",
CONF_SSL: False,
CONF_GENERIC_OPENAI_PATH: "v1",
}
)
assert "Generic OpenAI" in name
assert "localhost" in name
def test_normalize_path_helper():
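    # None/empty paths normalize to "", other paths gain a leading slash and lose any trailing slash.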
assert _normalize_path(None) == ""
assert _normalize_path("") == ""
assert _normalize_path("/v1/") == "/v1"
assert _normalize_path("v2") == "/v2"

View File

@@ -1,350 +1,204 @@
"""Config flow option schema tests to ensure options are wired per-backend."""
import pytest import pytest
from unittest.mock import patch, MagicMock
from homeassistant import config_entries, setup from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.const import ( from homeassistant.helpers import llm
CONF_HOST,
CONF_PORT,
CONF_SSL,
CONF_LLM_HASS_API,
)
from homeassistant.data_entry_flow import FlowResultType
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema, ConfigFlow from custom_components.llama_conversation.config_flow import local_llama_config_option_schema
from custom_components.llama_conversation.const import ( from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL, BACKEND_TYPE_LLAMA_CPP,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_PROMPT_TEMPLATE,
CONF_TOOL_FORMAT,
CONF_TOOL_MULTI_TURN_CHAT,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
CONF_BATCH_SIZE,
CONF_THREAD_COUNT,
CONF_BATCH_THREAD_COUNT,
BACKEND_TYPE_LLAMA_HF,
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI, BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI, BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER, BACKEND_TYPE_LLAMA_CPP_SERVER,
BACKEND_TYPE_OLLAMA, BACKEND_TYPE_OLLAMA,
DEFAULT_CHAT_MODEL, CONF_CONTEXT_LENGTH,
DEFAULT_PROMPT, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_MAX_TOKENS, CONF_GBNF_GRAMMAR_FILE,
DEFAULT_TEMPERATURE, CONF_LLAMACPP_BATCH_SIZE,
DEFAULT_TOP_K, CONF_LLAMACPP_BATCH_THREAD_COUNT,
DEFAULT_TOP_P, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
DEFAULT_MIN_P, CONF_LLAMACPP_THREAD_COUNT,
DEFAULT_TYPICAL_P, CONF_MAX_TOKENS,
DEFAULT_BACKEND_TYPE, CONF_MIN_P,
DEFAULT_REQUEST_TIMEOUT, CONF_NUM_IN_CONTEXT_EXAMPLES,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, CONF_OLLAMA_JSON_MODE,
DEFAULT_PROMPT_TEMPLATE, CONF_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_ENABLE_FLASH_ATTENTION, CONF_PROMPT,
DEFAULT_USE_GBNF_GRAMMAR, CONF_PROMPT_CACHING_ENABLED,
DEFAULT_GBNF_GRAMMAR_FILE, CONF_PROMPT_CACHING_INTERVAL,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_REQUEST_TIMEOUT,
DEFAULT_IN_CONTEXT_EXAMPLES_FILE, CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_THINKING_PREFIX,
CONF_TOOL_CALL_PREFIX,
CONF_TOP_K,
CONF_TOP_P,
CONF_TYPICAL_P,
CONF_TEMPERATURE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_LLAMACPP_BATCH_SIZE,
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
DEFAULT_LLAMACPP_THREAD_COUNT,
DEFAULT_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_PROMPT_CACHING_ENABLED,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH, DEFAULT_PROMPT,
DEFAULT_BATCH_SIZE, DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_THREAD_COUNT, DEFAULT_REQUEST_TIMEOUT,
DEFAULT_BATCH_THREAD_COUNT, DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DOMAIN, DEFAULT_THINKING_PREFIX,
DEFAULT_TOOL_CALL_PREFIX,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_TYPICAL_P,
) )
# async def test_validate_config_flow_llama_hf(hass: HomeAssistant):
# result = await hass.config_entries.flow.async_init(
# DOMAIN, context={"source": config_entries.SOURCE_USER}
# )
# assert result["type"] == FlowResultType.FORM
# assert result["errors"] is None
# result2 = await hass.config_entries.flow.async_configure( def _schema(hass: HomeAssistant, backend: str, options: dict | None = None):
# result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_HF }, return local_llama_config_option_schema(
# ) hass=hass,
# assert result2["type"] == FlowResultType.FORM language="en",
options=options or {},
# with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry: backend_type=backend,
# result3 = await hass.config_entries.flow.async_configure( subentry_type="conversation",
# result2["flow_id"],
# TEST_DATA,
# )
# await hass.async_block_till_done()
# assert result3["type"] == "create_entry"
# assert result3["title"] == ""
# assert result3["data"] == {
# # ACCOUNT_ID: TEST_DATA["account_id"],
# # CONF_PASSWORD: TEST_DATA["password"],
# # CONNECTION_TYPE: CLOUD,
# }
# assert result3["options"] == {}
# assert len(mock_setup_entry.mock_calls) == 1
@pytest.fixture
def validate_connections_mock():
validate_mock = MagicMock()
with patch.object(ConfigFlow, '_validate_text_generation_webui', new=validate_mock), \
patch.object(ConfigFlow, '_validate_ollama', new=validate_mock):
yield validate_mock
@pytest.fixture
def mock_setup_entry():
with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry, \
patch("custom_components.llama_conversation.async_unload_entry", return_value=True):
yield mock_setup_entry
async def test_validate_config_flow_generic_openai(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.FORM
assert result["errors"] == {}
assert result["step_id"] == "pick_backend"
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI },
) )
assert result2["type"] == FlowResultType.FORM
assert result2["errors"] == {}
assert result2["step_id"] == "remote_model"
result3 = await hass.config_entries.flow.async_configure( def _get_default(schema: dict, key_name: str):
result2["flow_id"], for key in schema:
{ if getattr(key, "schema", None) == key_name:
CONF_HOST: "localhost", default = getattr(key, "default", None)
CONF_PORT: "5000", return default() if callable(default) else default
CONF_SSL: False, raise AssertionError(f"Key {key_name} not found in schema")
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
},
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {}
assert result3["step_id"] == "model_parameters"
options_dict = { def _get_suggested(schema: dict, key_name: str):
CONF_PROMPT: DEFAULT_PROMPT, for key in schema:
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS, if getattr(key, "schema", None) == key_name:
CONF_TOP_P: DEFAULT_TOP_P, return (getattr(key, "description", {}) or {}).get("suggested_value")
CONF_TEMPERATURE: DEFAULT_TEMPERATURE, raise AssertionError(f"Key {key_name} not found in schema")
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, def test_schema_llama_cpp_defaults_and_overrides(hass: HomeAssistant):
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT, overrides = {
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION, CONF_CONTEXT_LENGTH: 4096,
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS, CONF_LLAMACPP_BATCH_SIZE: 8,
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX, CONF_LLAMACPP_THREAD_COUNT: 6,
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT, CONF_LLAMACPP_BATCH_THREAD_COUNT: 3,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION: True,
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE, CONF_PROMPT_CACHING_INTERVAL: 15,
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES, CONF_TOP_K: 12,
CONF_TOOL_CALL_PREFIX: "<tc>",
} }
result4 = await hass.config_entries.flow.async_configure( schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP, overrides)
result2["flow_id"], options_dict
)
await hass.async_block_till_done()
assert result4["type"] == "create_entry" expected_keys = {
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)" CONF_MAX_TOKENS,
assert result4["data"] == { CONF_CONTEXT_LENGTH,
CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI, CONF_TOP_K,
CONF_HOST: "localhost", CONF_TOP_P,
CONF_PORT: "5000", CONF_MIN_P,
CONF_SSL: False, CONF_TYPICAL_P,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL, CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_GBNF_GRAMMAR_FILE,
CONF_LLAMACPP_BATCH_SIZE,
CONF_LLAMACPP_THREAD_COUNT,
CONF_LLAMACPP_BATCH_THREAD_COUNT,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
} }
assert result4["options"] == options_dict assert expected_keys.issubset({getattr(k, "schema", None) for k in schema})
assert len(mock_setup_entry.mock_calls) == 1
async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations, validate_connections_mock): assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
result = await hass.config_entries.flow.async_init( assert _get_default(schema, CONF_LLAMACPP_BATCH_SIZE) == DEFAULT_LLAMACPP_BATCH_SIZE
DOMAIN, context={"source": config_entries.SOURCE_USER} assert _get_default(schema, CONF_LLAMACPP_THREAD_COUNT) == DEFAULT_LLAMACPP_THREAD_COUNT
) assert _get_default(schema, CONF_LLAMACPP_BATCH_THREAD_COUNT) == DEFAULT_LLAMACPP_BATCH_THREAD_COUNT
assert result["type"] == FlowResultType.FORM assert _get_default(schema, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION) is DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION
assert result["errors"] == {} assert _get_default(schema, CONF_PROMPT_CACHING_INTERVAL) == DEFAULT_PROMPT_CACHING_INTERVAL
assert result["step_id"] == "pick_backend" # suggested values should reflect overrides
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 4096
assert _get_suggested(schema, CONF_LLAMACPP_BATCH_SIZE) == 8
assert _get_suggested(schema, CONF_LLAMACPP_THREAD_COUNT) == 6
assert _get_suggested(schema, CONF_LLAMACPP_BATCH_THREAD_COUNT) == 3
assert _get_suggested(schema, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION) is True
assert _get_suggested(schema, CONF_PROMPT_CACHING_INTERVAL) == 15
assert _get_suggested(schema, CONF_TOP_K) == 12
assert _get_suggested(schema, CONF_TOOL_CALL_PREFIX) == "<tc>"
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA },
)
assert result2["type"] == FlowResultType.FORM def test_schema_text_gen_webui_options_preserved(hass: HomeAssistant):
assert result2["errors"] == {} overrides = {
assert result2["step_id"] == "remote_model" CONF_REQUEST_TIMEOUT: 123,
CONF_TEXT_GEN_WEBUI_PRESET: "custom-preset",
# simulate incorrect settings on first try CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
validate_connections_mock.side_effect = [ CONF_CONTEXT_LENGTH: 2048,
("failed_to_connect", Exception("ConnectionError"), []),
(None, None, [])
]
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
{
CONF_HOST: "localhost",
CONF_PORT: "5000",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
},
)
assert result3["type"] == FlowResultType.FORM
assert len(result3["errors"]) == 1
assert "base" in result3["errors"]
assert result3["step_id"] == "remote_model"
# retry
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
{
CONF_HOST: "localhost",
CONF_PORT: "5001",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
},
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {}
assert result3["step_id"] == "model_parameters"
options_dict = {
CONF_PROMPT: DEFAULT_PROMPT,
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
CONF_TOP_P: DEFAULT_TOP_P,
CONF_TOP_K: DEFAULT_TOP_K,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
CONF_TYPICAL_P: DEFAULT_MIN_P,
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
CONF_OLLAMA_KEEP_ALIVE_MIN: DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE: DEFAULT_OLLAMA_JSON_MODE,
} }
result4 = await hass.config_entries.flow.async_configure( schema = _schema(hass, BACKEND_TYPE_TEXT_GEN_WEBUI, overrides)
result2["flow_id"], options_dict
expected = {CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, CONF_REQUEST_TIMEOUT, CONF_CONTEXT_LENGTH}
assert expected.issubset({getattr(k, "schema", None) for k in schema})
assert _get_default(schema, CONF_REQUEST_TIMEOUT) == DEFAULT_REQUEST_TIMEOUT
assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
assert _get_suggested(schema, CONF_REQUEST_TIMEOUT) == 123
assert _get_suggested(schema, CONF_TEXT_GEN_WEBUI_PRESET) == "custom-preset"
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 2048
def test_schema_generic_openai_options_preserved(hass: HomeAssistant):
overrides = {CONF_TOP_P: 0.25, CONF_REQUEST_TIMEOUT: 321}
schema = _schema(hass, BACKEND_TYPE_GENERIC_OPENAI, overrides)
assert {CONF_TOP_P, CONF_REQUEST_TIMEOUT}.issubset({getattr(k, "schema", None) for k in schema})
assert _get_default(schema, CONF_TOP_P) == DEFAULT_TOP_P
assert _get_default(schema, CONF_REQUEST_TIMEOUT) == DEFAULT_REQUEST_TIMEOUT
assert _get_suggested(schema, CONF_TOP_P) == 0.25
assert _get_suggested(schema, CONF_REQUEST_TIMEOUT) == 321
# Base prompt options still present
prompt_default = _get_default(schema, CONF_PROMPT)
assert prompt_default is not None and "You are 'Al'" in prompt_default
assert _get_default(schema, CONF_NUM_IN_CONTEXT_EXAMPLES) == DEFAULT_NUM_IN_CONTEXT_EXAMPLES
def test_schema_llama_cpp_server_includes_gbnf(hass: HomeAssistant):
schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP_SERVER)
keys = {getattr(k, "schema", None) for k in schema}
assert {CONF_MAX_TOKENS, CONF_TOP_K, CONF_GBNF_GRAMMAR_FILE}.issubset(keys)
assert _get_default(schema, CONF_GBNF_GRAMMAR_FILE) == "output.gbnf"
def test_schema_ollama_defaults_and_overrides(hass: HomeAssistant):
overrides = {CONF_OLLAMA_KEEP_ALIVE_MIN: 5, CONF_CONTEXT_LENGTH: 1024, CONF_TOP_K: 7}
schema = _schema(hass, BACKEND_TYPE_OLLAMA, overrides)
assert {CONF_MAX_TOKENS, CONF_CONTEXT_LENGTH, CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE}.issubset(
{getattr(k, "schema", None) for k in schema}
) )
await hass.async_block_till_done() assert _get_default(schema, CONF_OLLAMA_KEEP_ALIVE_MIN) == DEFAULT_OLLAMA_KEEP_ALIVE_MIN
assert _get_default(schema, CONF_OLLAMA_JSON_MODE) is DEFAULT_OLLAMA_JSON_MODE
assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
assert _get_default(schema, CONF_TOP_K) == DEFAULT_TOP_K
assert _get_suggested(schema, CONF_OLLAMA_KEEP_ALIVE_MIN) == 5
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 1024
assert _get_suggested(schema, CONF_TOP_K) == 7
assert result4["type"] == "create_entry"
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
assert result4["data"] == {
CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA,
CONF_HOST: "localhost",
CONF_PORT: "5001",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
}
assert result4["options"] == options_dict
mock_setup_entry.assert_called_once()
# TODO: write tests for configflow setup for llama.cpp (both versions) + text-generation-webui def test_schema_includes_llm_api_selector(monkeypatch, hass: HomeAssistant):
monkeypatch.setattr(
"custom_components.llama_conversation.config_flow.llm.async_get_apis",
lambda _hass: [type("API", (), {"id": "dummy", "name": "Dummy API", "tools": []})()],
)
schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP)
def test_validate_options_schema(hass: HomeAssistant): assert _get_default(schema, CONF_LLM_HASS_API) is None
# Base prompt and thinking prefixes use defaults when not overridden
universal_options = [ prompt_default = _get_default(schema, CONF_PROMPT)
CONF_LLM_HASS_API, CONF_PROMPT, CONF_PROMPT_TEMPLATE, CONF_TOOL_FORMAT, CONF_TOOL_MULTI_TURN_CHAT, assert prompt_default is not None and "You are 'Al'" in prompt_default
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES, assert _get_default(schema, CONF_THINKING_PREFIX) == DEFAULT_THINKING_PREFIX
CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, assert _get_default(schema, CONF_TOOL_CALL_PREFIX) == DEFAULT_TOOL_CALL_PREFIX
CONF_SERVICE_CALL_REGEX, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS,
]
options_llama_hf = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_HF)
assert set(options_llama_hf.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
])
options_llama_existing = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_EXISTING)
assert set(options_llama_existing.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
])
options_ollama = local_llama_config_option_schema(hass, None, BACKEND_TYPE_OLLAMA)
assert set(options_ollama.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_TYPICAL_P, # supports top_k temperature, top_p and typical_p samplers
CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE, # ollama specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_text_gen_webui = local_llama_config_option_schema(hass, None, BACKEND_TYPE_TEXT_GEN_WEBUI)
assert set(options_text_gen_webui.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, # text-gen-webui specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_generic_openai = local_llama_config_option_schema(hass, None, BACKEND_TYPE_GENERIC_OPENAI)
assert set(options_generic_openai.keys()) == set(universal_options + [
CONF_TEMPERATURE, CONF_TOP_P, # only supports top_p and temperature sampling
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_llama_cpp_python_server = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER)
assert set(options_llama_cpp_python_server.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports top_k, temperature, and top p sampling
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])

View File

@@ -0,0 +1,114 @@
"""Tests for LocalLLMAgent async_process."""
import pytest
from contextlib import contextmanager
from homeassistant.components.conversation import ConversationInput, SystemContent, AssistantContent
from homeassistant.const import MATCH_ALL
from custom_components.llama_conversation.conversation import LocalLLMAgent
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
DEFAULT_PROMPT,
DOMAIN,
)
class DummyClient:
def __init__(self, hass):
self.hass = hass
self.generated_prompts = []
def _generate_system_prompt(self, prompt_template, llm_api, entity_options):
self.generated_prompts.append(prompt_template)
return "rendered-system-prompt"
async def _async_generate(self, conv, agent_id, chat_log, entity_options):
async def gen():
yield AssistantContent(agent_id=agent_id, content="hello from llm")
return gen()
class DummySubentry:
def __init__(self, subentry_id="sub1", title="Test Agent", chat_model="model"):
self.subentry_id = subentry_id
self.title = title
self.subentry_type = DOMAIN
self.data = {CONF_CHAT_MODEL: chat_model}
class DummyEntry:
def __init__(self, entry_id="entry1", options=None, subentry=None, runtime_data=None):
self.entry_id = entry_id
self.options = options or {}
self.subentries = {subentry.subentry_id: subentry}
self.runtime_data = runtime_data
def add_update_listener(self, _cb):
return lambda: None
class FakeChatLog:
def __init__(self):
self.content = []
self.llm_api = None
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
class FakeChatSession:
def __enter__(self):
return {}
def __exit__(self, exc_type, exc, tb):
return False
@pytest.mark.asyncio
async def test_async_process_generates_response(monkeypatch, hass):
client = DummyClient(hass)
subentry = DummySubentry()
entry = DummyEntry(subentry=subentry, runtime_data=client)
# Make entry discoverable through hass data as LocalLLMEntity expects.
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry
@contextmanager
def fake_chat_session(_hass, _conversation_id):
yield FakeChatSession()
@contextmanager
def fake_chat_log(_hass, _session, _user_input):
yield FakeChatLog()
monkeypatch.setattr(
"custom_components.llama_conversation.conversation.chat_session.async_get_chat_session",
fake_chat_session,
)
monkeypatch.setattr(
"custom_components.llama_conversation.conversation.conversation.async_get_chat_log",
fake_chat_log,
)
agent = LocalLLMAgent(hass, entry, subentry, client)
result = await agent.async_process(
ConversationInput(
text="turn on the lights",
context=None,
conversation_id="conv-id",
device_id=None,
language="en",
agent_id="agent-1",
)
)
assert result.response.speech["plain"]["speech"] == "hello from llm"
# System prompt should be rendered once when message history is empty.
assert client.generated_prompts == [DEFAULT_PROMPT]
assert agent.supported_languages == MATCH_ALL

View File

@@ -0,0 +1,162 @@
"""Tests for LocalLLMClient helpers in entity.py."""
import inspect
import json
import pytest
from json import JSONDecodeError
from custom_components.llama_conversation.entity import LocalLLMClient
from custom_components.llama_conversation.const import (
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_TOOL_CALL_PREFIX,
DEFAULT_TOOL_CALL_SUFFIX,
DEFAULT_THINKING_PREFIX,
DEFAULT_THINKING_SUFFIX,
)
class DummyLocalClient(LocalLLMClient):
@staticmethod
def get_name(_client_options):
return "dummy"
class DummyLLMApi:
def __init__(self):
self.tools = []
@pytest.fixture
def client(hass):
# Disable ICL loading during tests to avoid filesystem access.
return DummyLocalClient(hass, {CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False})
@pytest.mark.asyncio
async def test_async_parse_completion_parses_tool_call(client):
raw_tool = '{"name":"light.turn_on","arguments":{"brightness":0.5,"to_say":" acknowledged"}}'
completion = (
f"{DEFAULT_THINKING_PREFIX}internal{DEFAULT_THINKING_SUFFIX}"
f"hello {DEFAULT_TOOL_CALL_PREFIX}{raw_tool}{DEFAULT_TOOL_CALL_SUFFIX}"
)
result = await client._async_parse_completion(DummyLLMApi(), "agent-id", {}, completion)
assert result.response.strip().startswith("hello")
assert "acknowledged" in result.response
assert result.tool_calls
tool_call = result.tool_calls[0]
assert tool_call.tool_name == "light.turn_on"
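    # The parser appears to scale fractional brightness (0.5) onto the 0-255 range, hence 127.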
assert tool_call.tool_args["brightness"] == 127
@pytest.mark.asyncio
async def test_async_parse_completion_ignores_tools_without_llm_api(client):
raw_tool = '{"name":"light.turn_on","arguments":{"brightness":1}}'
completion = f"hello {DEFAULT_TOOL_CALL_PREFIX}{raw_tool}{DEFAULT_TOOL_CALL_SUFFIX}"
result = await client._async_parse_completion(None, "agent-id", {}, completion)
assert result.tool_calls == []
assert result.response.strip() == "hello"
@pytest.mark.asyncio
async def test_async_parse_completion_malformed_tool_raises(client):
bad_tool = f"{DEFAULT_TOOL_CALL_PREFIX}{{not-json{DEFAULT_TOOL_CALL_SUFFIX}"
with pytest.raises(JSONDecodeError):
await client._async_parse_completion(DummyLLMApi(), "agent-id", {}, bad_tool)
@pytest.mark.asyncio
async def test_async_stream_parse_completion_handles_streamed_tool_call(client):
async def token_generator():
yield ("Hi", None)
yield (
None,
[
{
"function": {
"name": "light.turn_on",
"arguments": {"brightness": 0.25, "to_say": " ok"},
}
}
],
)
stream = client._async_stream_parse_completion(
DummyLLMApi(), "agent-id", {}, anext_token=token_generator()
)
results = [chunk async for chunk in stream]
assert results[0].response == "Hi"
assert results[1].response.strip() == "ok"
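    # Streamed tool arguments seem to get the same fractional-to-0-255 brightness scaling (0.25 -> 63).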
assert results[1].tool_calls[0].tool_args["brightness"] == 63
@pytest.mark.asyncio
async def test_async_stream_parse_completion_malformed_tool_raises(client):
async def token_generator():
yield ("Hi", None)
yield (None, ["{not-json"])
with pytest.raises(JSONDecodeError):
async for _chunk in client._async_stream_parse_completion(
DummyLLMApi(), "agent-id", {}, anext_token=token_generator()
):
pass
@pytest.mark.asyncio
async def test_async_stream_parse_completion_ignores_tools_without_llm_api(client):
async def token_generator():
yield ("Hi", None)
yield (None, ["{}"])
results = [chunk async for chunk in client._async_stream_parse_completion(
None, "agent-id", {}, anext_token=token_generator()
)]
assert results[0].response == "Hi"
assert results[1].tool_calls is None
@pytest.mark.asyncio
async def test_async_get_exposed_entities_respects_exposure(monkeypatch, client, hass):
hass.states.async_set("light.exposed", "on", {"friendly_name": "Lamp"})
hass.states.async_set("switch.hidden", "off", {"friendly_name": "Hidden"})
monkeypatch.setattr(
"custom_components.llama_conversation.entity.async_should_expose",
lambda _hass, _domain, entity_id: not entity_id.endswith("hidden"),
)
exposed = client._async_get_exposed_entities()
assert "light.exposed" in exposed
assert "switch.hidden" not in exposed
assert exposed["light.exposed"]["friendly_name"] == "Lamp"
assert exposed["light.exposed"]["state"] == "on"
@pytest.mark.asyncio
async def test_generate_system_prompt_renders(monkeypatch, client, hass):
hass.states.async_set("light.kitchen", "on", {"friendly_name": "Kitchen"})
monkeypatch.setattr(
"custom_components.llama_conversation.entity.async_should_expose",
lambda _hass, _domain, _entity_id: True,
)
rendered = client._generate_system_prompt(
"Devices:\n{{ formatted_devices }}",
llm_api=None,
entity_options={CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: []},
)
if inspect.iscoroutine(rendered):
rendered = await rendered
assert isinstance(rendered, str)
assert "light.kitchen" in rendered

View File

@@ -0,0 +1,159 @@
"""Regression tests for config entry migration in __init__.py."""
import pytest
from homeassistant.const import CONF_LLM_HASS_API, CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.config_entries import ConfigSubentry
from pytest_homeassistant_custom_component.common import MockConfigEntry
from custom_components.llama_conversation import async_migrate_entry
from custom_components.llama_conversation.const import (
BACKEND_TYPE_LLAMA_CPP,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_SERVER,
CONF_BACKEND_TYPE,
CONF_CHAT_MODEL,
CONF_CONTEXT_LENGTH,
CONF_DOWNLOADED_MODEL_FILE,
CONF_DOWNLOADED_MODEL_QUANTIZATION,
CONF_GENERIC_OPENAI_PATH,
CONF_PROMPT,
CONF_REQUEST_TIMEOUT,
DOMAIN,
)
@pytest.mark.asyncio
async def test_migrate_v1_is_rejected(hass):
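    # Version 1 entries predate the supported migration path, so migration must report failure.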
entry = MockConfigEntry(domain=DOMAIN, data={CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_CPP}, version=1)
entry.add_to_hass(hass)
result = await async_migrate_entry(hass, entry)
assert result is False
@pytest.mark.asyncio
async def test_migrate_v2_creates_subentry_and_updates_entry(monkeypatch, hass):
entry = MockConfigEntry(
domain=DOMAIN,
title="llama 'Test Agent' entry",
data={CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI},
options={
CONF_HOST: "localhost",
CONF_PORT: "8080",
CONF_SSL: False,
CONF_GENERIC_OPENAI_PATH: "v1",
CONF_PROMPT: "hello",
CONF_REQUEST_TIMEOUT: 90,
CONF_CHAT_MODEL: "model-x",
CONF_CONTEXT_LENGTH: 1024,
},
version=2,
)
entry.add_to_hass(hass)
added_subentries = []
update_calls = []
def fake_add_subentry(cfg_entry, subentry):
added_subentries.append((cfg_entry, subentry))
def fake_update_entry(cfg_entry, **kwargs):
update_calls.append(kwargs)
monkeypatch.setattr(hass.config_entries, "async_add_subentry", fake_add_subentry)
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
result = await async_migrate_entry(hass, entry)
assert result is True
assert added_subentries, "Subentry should be added"
subentry = added_subentries[0][1]
assert isinstance(subentry, ConfigSubentry)
assert subentry.subentry_type == "conversation"
assert subentry.data[CONF_CHAT_MODEL] == "model-x"
# Entry should be updated to version 3 with data/options separated
assert any(call.get("version") == 3 for call in update_calls)
last_options = [c["options"] for c in update_calls if "options" in c][-1]
assert last_options[CONF_HOST] == "localhost"
assert CONF_PROMPT not in last_options # moved to subentry
@pytest.mark.asyncio
async def test_migrate_v3_minor0_downloads_model(monkeypatch, hass):
sub_data = {
CONF_CHAT_MODEL: "model-a",
CONF_DOWNLOADED_MODEL_QUANTIZATION: "Q4_K_M",
CONF_REQUEST_TIMEOUT: 30,
}
subentry = ConfigSubentry(data=sub_data, subentry_type="conversation", title="sub", unique_id=None)
entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_CPP},
options={},
version=3,
minor_version=0,
)
entry.subentries = {"sub": subentry}
entry.add_to_hass(hass)
updated_subentries = []
update_calls = []
def fake_update_subentry(cfg_entry, old_sub, *, data=None, **_kwargs):
updated_subentries.append((cfg_entry, old_sub, data))
def fake_update_entry(cfg_entry, **kwargs):
update_calls.append(kwargs)
monkeypatch.setattr(
"custom_components.llama_conversation.download_model_from_hf", lambda *_args, **_kw: "file.gguf"
)
monkeypatch.setattr(hass.config_entries, "async_update_subentry", fake_update_subentry)
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
result = await async_migrate_entry(hass, entry)
assert result is True
assert updated_subentries, "Subentry should be updated with downloaded file"
new_data = updated_subentries[0][2]
assert new_data[CONF_DOWNLOADED_MODEL_FILE] == "file.gguf"
assert any(call.get("minor_version") == 1 for call in update_calls)
@pytest.mark.parametrize(
"api_value,expected_list",
[("api-1", ["api-1"]), (None, [])],
)
@pytest.mark.asyncio
async def test_migrate_v3_minor1_converts_api_to_list(monkeypatch, hass, api_value, expected_list):
entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI},
options={CONF_LLM_HASS_API: api_value},
version=3,
minor_version=1,
)
entry.add_to_hass(hass)
calls = []
def fake_update_entry(cfg_entry, **kwargs):
calls.append(kwargs)
if "options" in kwargs:
cfg_entry._options = kwargs["options"] # type: ignore[attr-defined]
if "minor_version" in kwargs:
cfg_entry._minor_version = kwargs["minor_version"] # type: ignore[attr-defined]
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
result = await async_migrate_entry(hass, entry)
assert result is True
options_calls = [c for c in calls if "options" in c]
assert options_calls, "async_update_entry should be called with options"
assert options_calls[-1]["options"][CONF_LLM_HASS_API] == expected_list
minor_calls = [c for c in calls if c.get("minor_version")]
assert minor_calls and minor_calls[-1]["minor_version"] == 2