Merge branch 'develop' into feature/dataset-new-apis

This commit is contained in:
Alex O'Connell
2025-12-14 20:43:52 -05:00
27 changed files with 1638 additions and 1493 deletions

.gitignore vendored
View File

@@ -11,4 +11,5 @@ main.log
*.xlsx
notes.txt
runpod_bootstrap.sh
*.code-workspace
*.code-workspace
.coverage

View File

@@ -4,16 +4,17 @@ This project provides the required "glue" components to control your Home Assist
## Quick Start
Please see the [Setup Guide](./docs/Setup.md) for more information on installation.
## Local LLM Conversation Integration
## Local LLM Integration
**The latest version of this integration requires Home Assistant 2025.7.0 or newer**
In order to integrate with Home Assistant, we provide a custom component that exposes the locally running LLM as a "conversation agent".
In order to integrate with Home Assistant, we provide a custom component that exposes the locally running LLM as a "conversation agent" or as an "AI task handler".
This component can be interacted with in a few ways:
- using a chat interface so you can type messages to it.
- integrating with Speech-to-Text and Text-to-Speech addons so you can just speak to it.
- using automations or scripts to trigger "AI tasks"; these process input data with a prompt and return structured data that can be used in further automations (see the sketch below).
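As a rough sketch of that last bullet: an automation, script, or custom component can invoke the AI task and consume the structured result. The snippet below assumes Home Assistant's core `ai_task.generate_data` action and its `task_name`/`instructions`/`structure` fields (available in 2025.7+); the entity id, task name, and output schema are illustrative placeholders, not part of this integration.

```python
from homeassistant.core import HomeAssistant, ServiceResponse


async def summarize_calendar(hass: HomeAssistant) -> ServiceResponse:
    """Trigger an AI task and return its structured result (illustrative sketch)."""
    # Field names follow the core ai_task.generate_data action; the entity_id,
    # task name, and schema below are placeholders for your own setup.
    result = await hass.services.async_call(
        "ai_task",
        "generate_data",
        {
            "task_name": "summarize_calendar",
            "entity_id": "ai_task.local_llm",  # the AI Task entity created by this integration
            "instructions": "Summarize today's calendar events in one short sentence.",
            # Requested output schema, using Home Assistant selector syntax
            "structure": {"summary": {"selector": {"text": None}}},
        },
        blocking=True,
        return_response=True,
    )
    # The service response is expected to expose the generated data under "data"
    return result
```

The equivalent YAML automation step calls the same `ai_task.generate_data` action with these fields and can feed the returned data into later actions.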
The integration can either run the model in 2 different ways:
The integration can run the model in a few different ways:
1. Directly as part of the Home Assistant software using llama-cpp-python
2. On a separate machine using one of the following backends:
- [Ollama](https://ollama.com/) (easier)
@@ -36,6 +37,7 @@ The latest models can be found on HuggingFace:
**Gemma3**:
1B: TBD
270M: TBD
<details>
@@ -158,6 +160,7 @@ python3 train.py \
## Version History
| Version | Description |
|---------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| v0.4.5 | Add support for AI Task entities, replace the custom Ollama API implementation with the official `ollama-python` package to avoid future compatibility issues, support multiple LLM APIs at once, and fix issues in tool call handling for various backends |
| v0.4.4 | Fix issue with OpenAI backends appending `/v1` to all URLs, and fix an issue with tools being serialized into the system prompt. |
| v0.4.3 | Fix an issue with the integration not creating model configs properly during setup |
| v0.4.2 | Fix the following issues: not correctly setting default model settings during initial setup, non-integers being allowed in numeric config fields, being too strict with finish_reason requirements, and not letting the user clear the active LLM API |

View File

@@ -57,7 +57,7 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION, ) # Platform.AI_TASK)
PLATFORMS = (Platform.CONVERSATION, Platform.AI_TASK)
BACKEND_TO_CLS: dict[str, type[LocalLLMClient]] = {
BACKEND_TYPE_LLAMA_CPP: LlamaCppClient,
@@ -184,6 +184,21 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigE
_LOGGER.debug("Migration to add downloaded model file complete")
if config_entry.version == 3 and config_entry.minor_version == 1:
# convert selected APIs from single value to list
api_to_convert = config_entry.options.get(CONF_LLM_HASS_API)
new_options = dict(config_entry.options)
if api_to_convert is not None:
new_options[CONF_LLM_HASS_API] = [api_to_convert]
else:
new_options[CONF_LLM_HASS_API] = []
hass.config_entries.async_update_entry(
config_entry, options=MappingProxyType(new_options)
)
hass.config_entries.async_update_entry(config_entry, minor_version=2)
return True
class HassServiceTool(llm.Tool):
@@ -255,14 +270,14 @@ class HomeLLMAPI(llm.API):
super().__init__(
hass=hass,
id=HOME_LLM_API_ID,
name="Home-LLM (v1-v3)",
name="Home Assistant Services",
)
async def async_get_api_instance(self, llm_context: llm.LLMContext) -> llm.APIInstance:
"""Return the instance of the API."""
return llm.APIInstance(
api=self,
api_prompt="Call services in Home Assistant by passing the service name and the device to control.",
api_prompt="Call services in Home Assistant by passing the service name and the device to control. Designed for Home-LLM Models (v1-v3)",
llm_context=llm_context,
tools=[HassServiceTool()],
)

View File

@@ -1,14 +1,18 @@
"""AI Task integration for Local LLMs."""
from __future__ import annotations
from json import JSONDecodeError
import logging
from enum import StrEnum
from typing import Any
import voluptuous as vol
from voluptuous_openapi import convert as convert_to_openapi
from homeassistant.helpers import llm
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, Context
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
@@ -16,7 +20,13 @@ from homeassistant.util.json import json_loads
from .entity import LocalLLMEntity, LocalLLMClient
from .const import (
CONF_PROMPT,
DEFAULT_PROMPT,
CONF_RESPONSE_JSON_SCHEMA,
DEFAULT_AI_TASK_PROMPT,
CONF_AI_TASK_RETRIES,
DEFAULT_AI_TASK_RETRIES,
CONF_AI_TASK_EXTRACTION_METHOD,
DEFAULT_AI_TASK_EXTRACTION_METHOD,
DOMAIN,
)
_LOGGER = logging.getLogger(__name__)
@@ -29,98 +39,222 @@ async def async_setup_entry(
) -> None:
"""Set up AI Task entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "ai_task_data":
if subentry.subentry_type != ai_task.DOMAIN:
continue
async_add_entities(
[LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)],
config_subentry_id=subentry.subentry_id,
)
# create one entity per subentry
ai_task_entity = LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)
# make sure model is loaded
await config_entry.runtime_data._async_load_model(dict(subentry.data))
# register the ai task entity
async_add_entities([ai_task_entity], config_subentry_id=subentry.subentry_id)
class ResultExtractionMethod(StrEnum):
NONE = "none"
STRUCTURED_OUTPUT = "structure"
TOOL = "tool"
class SubmitResponseTool(llm.Tool):
name = "submit_response"
description = "Submit the structured response payload for the AI task"
def __init__(self, parameters_schema: vol.Schema):
self.parameters = parameters_schema
async def async_call(
self,
hass: HomeAssistant,
tool_input: llm.ToolInput,
llm_context: llm.LLMContext,
) -> dict:
return tool_input.tool_args or {}
class SubmitResponseAPI(llm.API):
def __init__(self, hass: HomeAssistant, tools: list[llm.Tool]) -> None:
self._tools = tools
super().__init__(
hass=hass,
id=f"{DOMAIN}-ai-task-tool",
name="AI Task Tool API",
)
async def async_get_api_instance(
self, llm_context: llm.LLMContext
) -> llm.APIInstance:
return llm.APIInstance(
api=self,
api_prompt="Call submit_response to return the structured AI task result.",
llm_context=llm_context,
tools=self._tools,
custom_serializer=llm.selector_serializer,
)
class LocalLLMTaskEntity(
ai_task.AITaskEntity,
LocalLLMEntity,
):
"""Ollama AI Task entity."""
"""AI Task entity."""
def __init__(self, *args, **kwargs) -> None:
"""Initialize Ollama AI Task entity."""
"""Initialize AI Task entity."""
super().__init__(*args, **kwargs)
if self.client._supports_vision(self.runtime_options):
self._attr_supported_features = (
ai_task.AITaskEntityFeature.GENERATE_DATA |
ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
ai_task.AITaskEntityFeature.GENERATE_DATA
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
)
else:
self._attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
async def _generate_once(
self,
message_history: list[conversation.Content],
llm_api: llm.APIInstance | None,
entity_options: dict[str, Any],
) -> tuple[str, list | None, Exception | None]:
"""Generate a single response from the LLM."""
collected_tools = None
text = ""
# call the LLM client directly (not _async_generate) since that will attempt to execute tool calls
try:
if hasattr(self.client, "_generate_stream"):
async for chunk in self.client._generate_stream(
message_history,
llm_api,
self.entity_id,
entity_options,
):
if chunk.response:
text += chunk.response  # append chunks verbatim; stripping each chunk would drop whitespace between streamed tokens
if chunk.tool_calls:
collected_tools = chunk.tool_calls
else:
blocking_result = await self.client._generate(
message_history,
llm_api,
self.entity_id,
entity_options,
)
if blocking_result.response:
text = blocking_result.response.strip()
if blocking_result.tool_calls:
collected_tools = blocking_result.tool_calls
_LOGGER.debug("AI Task '%s' generated text: %s (tools=%s)", self.entity_id, text, collected_tools)
return text, collected_tools, None
except JSONDecodeError as err:
_LOGGER.debug("AI Task '%s' json error generated text: %s (tools=%s)", self.entity_id, text, collected_tools)
return text, collected_tools, err
def _extract_data(
self,
raw_text: str,
tool_calls: list[llm.ToolInput] | None,
extraction_method: ResultExtractionMethod,
chat_log: conversation.ChatLog,
structure: vol.Schema | None,
) -> tuple[ai_task.GenDataTaskResult | None, Exception | None]:
"""Extract the final data from the LLM response based on the extraction method."""
try:
if extraction_method == ResultExtractionMethod.NONE or structure is None:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=raw_text,
), None
if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
data = json_loads(raw_text)
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=data,
), None
if extraction_method == ResultExtractionMethod.TOOL:
first_tool = next(iter(tool_calls or []), None)
if not first_tool:
return None, HomeAssistantError("Please produce at least one tool call with the structured response.")
structure(first_tool.tool_args) # validate tool call against vol schema structure
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=first_tool.tool_args,
), None
except vol.Invalid as err:
if isinstance(err, vol.MultipleInvalid):
# combine all error messages into one
error_message = "; ".join(f"Error at '{e.path}': {e.error_message}" for e in err.errors)
else:
error_message = f"Error at '{err.path}': {err.error_message}"
return None, HomeAssistantError(f"Please address the following schema errors: {error_message}")
except JSONDecodeError as err:
return None, HomeAssistantError(f"Please produce properly formatted JSON: {repr(err)}")
raise HomeAssistantError(f"Invalid extraction method for AI Task {extraction_method}")
async def _async_generate_data(
self,
task: ai_task.GenDataTask,
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
raw_task_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT)
retries = max(0, self.runtime_options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES))
extraction_method = self.runtime_options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)
max_attempts = retries + 1
extraction_method = ResultExtractionMethod.NONE
try:
raw_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_PROMPT)
message_history = chat_log.content[:]
if not isinstance(message_history[0], conversation.SystemContent):
system_prompt = conversation.SystemContent(content=self.client._generate_system_prompt(raw_prompt, None, self.runtime_options))
message_history.insert(0, system_prompt)
_LOGGER.debug(f"Generating response for {task.name=}...")
generation_result = await self.client._async_generate(message_history, self.entity_id, chat_log, self.runtime_options)
assistant_message = await anext(generation_result)
if not isinstance(assistant_message, conversation.AssistantContent):
raise HomeAssistantError("Last content in chat log is not an AssistantContent!")
text = assistant_message.content
if not task.structure:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=text,
)
if extraction_method == ResultExtractionMethod.NONE:
raise HomeAssistantError("Task structure provided but no extraction method was specified!")
elif extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
try:
data = json_loads(text)
except JSONDecodeError as err:
_LOGGER.error(
"Failed to parse JSON response: %s. Response: %s",
err,
text,
)
raise HomeAssistantError("Error with Local LLM structured response") from err
entity_options = {**self.runtime_options}
if task.structure: # set up extraction method specifics
if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
_LOGGER.debug("Using structure for AI Task '%s': %s", task.name, task.structure)
entity_options[CONF_RESPONSE_JSON_SCHEMA] = convert_to_openapi(task.structure, custom_serializer=llm.selector_serializer)
elif extraction_method == ResultExtractionMethod.TOOL:
try:
data = assistant_message.tool_calls[0].tool_args
except (IndexError, AttributeError) as err:
_LOGGER.error(
"Failed to extract tool arguments from response: %s. Response: %s",
err,
text,
)
raise HomeAssistantError("Error with Local LLM tool response") from err
else:
raise ValueError() # should not happen
chat_log.llm_api = await SubmitResponseAPI(self.hass, [SubmitResponseTool(task.structure)]).async_get_api_instance(
llm.LLMContext(DOMAIN, context=None, language=None, assistant=None, device_id=None)
)
message_history = list(chat_log.content) if chat_log.content else []
task_prompt = self.client._generate_system_prompt(raw_task_prompt, llm_api=chat_log.llm_api, entity_options=entity_options)
system_message = conversation.SystemContent(content=task_prompt)
if message_history and isinstance(message_history[0], conversation.SystemContent):
message_history[0] = system_message
else:
message_history.insert(0, system_message)
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=data,
if not any(isinstance(msg, conversation.UserContent) for msg in message_history):
message_history.append(
conversation.UserContent(
content=task.instructions, attachments=task.attachments
)
)
try:
last_error: Exception | None = None
for attempt in range(max_attempts):
_LOGGER.debug("Generating response for %s (attempt %s/%s)...", task.name, attempt + 1, max_attempts)
text, tool_calls, err = await self._generate_once(message_history, chat_log.llm_api, entity_options)
if err:
last_error = err
message_history.append(conversation.AssistantContent(agent_id=self.entity_id, content=text, tool_calls=tool_calls))
message_history.append(conversation.UserContent(content=f"Error: {str(err)}. Please try again."))
continue
data, err = self._extract_data(text, tool_calls, extraction_method, chat_log, task.structure)
if err:
last_error = err
message_history.append(conversation.AssistantContent(agent_id=self.entity_id, content=text, tool_calls=tool_calls))
message_history.append(conversation.UserContent(content=f"Error: {str(err)}. Please try again."))
continue
if data:
return data
except Exception as err:
_LOGGER.exception("Unhandled exception while running AI Task '%s'", task.name)
raise HomeAssistantError(f"Unhandled error while running AI Task '{task.name}'") from err
raise last_error or HomeAssistantError(f"AI Task '{task.name}' failed after {max_attempts} attempts")

View File

@@ -28,6 +28,7 @@ from custom_components.llama_conversation.const import (
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
CONF_GENERIC_OPENAI_PATH,
CONF_ENABLE_LEGACY_TOOL_CALLING,
CONF_RESPONSE_JSON_SCHEMA,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
@@ -110,7 +111,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
) as response:
response.raise_for_status()
models_result = await response.json()
except:
except (asyncio.TimeoutError, aiohttp.ClientResponseError):
_LOGGER.exception("Failed to get available models")
return RECOMMENDED_CHAT_MODELS
@@ -120,7 +121,8 @@ class GenericOpenAIAPIClient(LocalLLMClient):
conversation: List[conversation.Content],
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
entity_options: dict[str, Any],
) -> AsyncGenerator[TextGenerationResult, None]:
model_name = entity_options[CONF_CHAT_MODEL]
temperature = entity_options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = entity_options.get(CONF_TOP_P, DEFAULT_TOP_P)
@@ -138,6 +140,17 @@ class GenericOpenAIAPIClient(LocalLLMClient):
"messages": messages
}
response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
if response_json_schema:
request_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": "ha_task",
"schema": response_json_schema,
"strict": True,
},
}
tools = None
# "legacy" tool calling passes the tools directly as part of the system prompt instead of as "tools"
# most local backends absolutely butcher any sort of prompt formatting when using tool calling
@@ -155,7 +168,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
session = async_get_clientsession(self.hass)
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
response = None
chunk = None
try:
@@ -175,7 +188,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
break
if chunk and chunk.strip():
to_say, tool_calls = self._extract_response(json.loads(chunk), llm_api, agent_id)
to_say, tool_calls = self._extract_response(json.loads(chunk))
if to_say or tool_calls:
yield to_say, tool_calls
except asyncio.TimeoutError as err:
@@ -183,14 +196,14 @@ class GenericOpenAIAPIClient(LocalLLMClient):
except aiohttp.ClientError as err:
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
endpoint = "/chat/completions"
return endpoint, request_params
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, agent_id: str) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
def _extract_response(self, response_json: dict) -> Tuple[Optional[str], Optional[List[dict]]]:
if "choices" not in response_json or len(response_json["choices"]) == 0: # finished
_LOGGER.warning("Response missing or empty 'choices'. Keys present: %s. Full response: %s",
list(response_json.keys()), response_json)
@@ -204,16 +217,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
elif response_json["object"] == "chat.completion.chunk":
response_text = choice["delta"].get("content", "")
if "tool_calls" in choice["delta"] and choice["delta"]["tool_calls"] is not None:
tool_calls = []
for call in choice["delta"]["tool_calls"]:
tool_call, to_say = parse_raw_tool_call(
call["function"], llm_api, agent_id)
if tool_call:
tool_calls.append(tool_call)
if to_say:
response_text += to_say
tool_calls = [call["function"] for call in choice["delta"]["tool_calls"]]
streamed = True
else:
response_text = choice["text"]
@@ -267,7 +271,6 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
try:
if msg.role == "user":
input_text = msg.content
break
except Exception:
continue
@@ -367,7 +370,8 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
conversation: List[conversation.Content],
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: dict[str, Any]) -> TextGenerationResult:
entity_options: dict[str, Any],
) -> TextGenerationResult:
"""Generate a response using the OpenAI-compatible Responses API (non-streaming endpoint wrapped as a single-chunk stream)."""
model_name = entity_options.get(CONF_CHAT_MODEL)
@@ -377,6 +381,16 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
request_params: Dict[str, Any] = {
"model": model_name,
}
response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
if response_json_schema:
request_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": "ha_task",
"schema": response_json_schema,
"strict": True,
},
}
request_params.update(additional_params)
headers: Dict[str, Any] = {}
@@ -398,7 +412,10 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
try:
text = self._extract_response(response_json)
return TextGenerationResult(response=text, response_streamed=False)
if not text:
return TextGenerationResult(raise_error=True, error_msg="The Responses API returned an empty response.")
# return await self._async_parse_completion(llm_api, agent_id, entity_options, text)
return TextGenerationResult(response=text) # Currently we don't extract any info from the response besides the raw model output
except Exception as err:
_LOGGER.exception("Failed to parse Responses API payload: %s", err)
return TextGenerationResult(raise_error=True, error_msg=f"Failed to parse Responses API payload: {err}")

View File

@@ -9,7 +9,6 @@ import time
from typing import Any, Callable, List, Generator, AsyncGenerator, Optional, cast
from homeassistant.components import conversation as conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
@@ -57,6 +56,7 @@ from custom_components.llama_conversation.const import (
DEFAULT_LLAMACPP_THREAD_COUNT,
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
DOMAIN,
CONF_RESPONSE_JSON_SCHEMA,
)
from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult
@@ -64,10 +64,15 @@ from custom_components.llama_conversation.entity import LocalLLMClient, TextGene
from typing import TYPE_CHECKING
from types import ModuleType
if TYPE_CHECKING:
from llama_cpp import Llama as LlamaType, LlamaGrammar as LlamaGrammarType
from llama_cpp import (
Llama as LlamaType,
LlamaGrammar as LlamaGrammarType,
ChatCompletionRequestResponseFormat
)
else:
LlamaType = Any
LlamaGrammarType = Any
ChatCompletionRequestResponseFormat = Any
_LOGGER = logging.getLogger(__name__)
@@ -283,8 +288,6 @@ class LlamaCppClient(LocalLLMClient):
# Sort the items based on the sort_key function
sorted_items = sorted(list(entity_order.items()), key=sort_key)
_LOGGER.debug(f"sorted_items: {sorted_items}")
sorted_entities: dict[str, dict[str, str]] = {}
for item_name, _ in sorted_items:
sorted_entities[item_name] = entities[item_name]
@@ -297,7 +300,7 @@ class LlamaCppClient(LocalLLMClient):
entity_ids = [
state.entity_id for state in self.hass.states.async_all() \
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id)
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
]
_LOGGER.debug(f"watching entities: {entity_ids}")
@@ -434,13 +437,21 @@ class LlamaCppClient(LocalLLMClient):
_LOGGER.debug(f"Options: {entity_options}")
messages = get_oai_formatted_messages(conversation, user_content_as_list=True)
messages = get_oai_formatted_messages(conversation)
tools = None
if llm_api:
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
_LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...")
response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
response_format: Optional[ChatCompletionRequestResponseFormat] = None
if response_json_schema:
response_format = {
"type": "json_object",
"schema": response_json_schema,
}
chat_completion = self.models[model_name].create_chat_completion(
messages,
tools=tools,
@@ -452,6 +463,7 @@ class LlamaCppClient(LocalLLMClient):
max_tokens=max_tokens,
grammar=grammar,
stream=True,
response_format=response_format,
)
def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:
@@ -464,5 +476,5 @@ class LlamaCppClient(LocalLLMClient):
tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
yield content, tool_calls
return self._async_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())
return self._async_stream_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())

View File

@@ -1,18 +1,19 @@
"""Defines the ollama compatible agent"""
"""Defines the Ollama compatible agent backed by the official python client."""
from __future__ import annotations
from warnings import deprecated
import aiohttp
import asyncio
import json
import logging
from typing import Optional, Tuple, Dict, List, Any, AsyncGenerator
import ssl
from collections.abc import Mapping
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
import certifi
import httpx
from ollama import AsyncClient, ChatResponse, ResponseError
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.components import conversation as conversation
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
@@ -23,119 +24,167 @@ from custom_components.llama_conversation.const import (
CONF_TOP_K,
CONF_TOP_P,
CONF_TYPICAL_P,
CONF_MIN_P,
CONF_ENABLE_THINK_MODE,
CONF_REQUEST_TIMEOUT,
CONF_OPENAI_API_KEY,
CONF_GENERIC_OPENAI_PATH,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
CONF_ENABLE_LEGACY_TOOL_CALLING,
CONF_RESPONSE_JSON_SCHEMA,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_TYPICAL_P,
DEFAULT_MIN_P,
DEFAULT_ENABLE_THINK_MODE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_GENERIC_OPENAI_PATH,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
)
from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult
_LOGGER = logging.getLogger(__name__)
@deprecated("Use the built-in Ollama integration instead")
def _normalize_path(path: str | None) -> str:
if not path:
return ""
trimmed = str(path).strip("/")
return f"/{trimmed}" if trimmed else ""
def _build_default_ssl_context() -> ssl.SSLContext:
context = ssl.create_default_context()
try:
context.load_verify_locations(certifi.where())
except OSError as err:
_LOGGER.debug("Failed to load certifi bundle for Ollama client: %s", err)
return context
class OllamaAPIClient(LocalLLMClient):
api_host: str
api_key: Optional[str]
def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
super().__init__(hass, client_options)
base_path = _normalize_path(client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
self.api_host = format_url(
hostname=client_options[CONF_HOST],
port=client_options[CONF_PORT],
ssl=client_options[CONF_SSL],
path=client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
path=base_path,
)
self.api_key = client_options.get(CONF_OPENAI_API_KEY) or None
self._headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else None
self._ssl_context = _build_default_ssl_context() if client_options.get(CONF_SSL) else None
self.api_key = client_options.get(CONF_OPENAI_API_KEY, "")
def _build_client(self, *, timeout: float | int | httpx.Timeout | None = None) -> AsyncClient:
timeout_config: httpx.Timeout | float | None = timeout
if isinstance(timeout, (int, float)):
timeout_config = httpx.Timeout(timeout)
return AsyncClient(
host=self.api_host,
headers=self._headers,
timeout=timeout_config,
verify=self._ssl_context,
)
@staticmethod
def get_name(client_options: dict[str, Any]):
host = client_options[CONF_HOST]
port = client_options[CONF_PORT]
ssl = client_options[CONF_SSL]
path = "/" + client_options[CONF_GENERIC_OPENAI_PATH]
path = _normalize_path(client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
return f"Ollama at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
@staticmethod
async def async_validate_connection(hass: HomeAssistant, user_input: Dict[str, Any]) -> str | None:
headers = {}
api_key = user_input.get(CONF_OPENAI_API_KEY)
api_base_path = user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
base_path = _normalize_path(user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
timeout_config: httpx.Timeout | float | None = httpx.Timeout(5)
verify_context = None
if user_input.get(CONF_SSL):
verify_context = await hass.async_add_executor_job(_build_default_ssl_context)
client = AsyncClient(
host=format_url(
hostname=user_input[CONF_HOST],
port=user_input[CONF_PORT],
ssl=user_input[CONF_SSL],
path=base_path,
),
headers={"Authorization": f"Bearer {api_key}"} if api_key else None,
timeout=timeout_config,
verify=verify_context,
)
try:
session = async_get_clientsession(hass)
async with session.get(
format_url(
hostname=user_input[CONF_HOST],
port=user_input[CONF_PORT],
ssl=user_input[CONF_SSL],
path=f"/{api_base_path}/api/tags"
),
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
headers=headers
) as response:
if response.ok:
return None
else:
return f"HTTP Status {response.status}"
except Exception as ex:
return str(ex)
await client.list()
except httpx.TimeoutException:
return "Connection timed out"
except ResponseError as err:
return f"HTTP Status {err.status_code}: {err.error}"
except ConnectionError as err:
return str(err)
return None
async def async_get_available_models(self) -> List[str]:
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
client = self._build_client(timeout=5)
try:
response = await client.list()
except httpx.TimeoutException as err:
raise HomeAssistantError("Timed out while fetching models from the Ollama server") from err
except (ResponseError, ConnectionError) as err:
raise HomeAssistantError(f"Failed to fetch models from the Ollama server: {err}") from err
session = async_get_clientsession(self.hass)
async with session.get(
f"{self.api_host}/api/tags",
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
headers=headers
) as response:
response.raise_for_status()
models_result = await response.json()
models: List[str] = []
for model in getattr(response, "models", []) or []:
candidate = getattr(model, "name", None) or getattr(model, "model", None)
if candidate:
models.append(candidate)
return [x["name"] for x in models_result["models"]]
return models
def _extract_response(self, response_json: Dict) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
# TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
# max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
# if response_json["prompt_eval_count"] + max_tokens > context_len:
# self._warn_context_size()
def _extract_response(self, response_chunk: ChatResponse) -> Tuple[Optional[str], Optional[List[dict]]]:
content = response_chunk.message.content
raw_tool_calls = response_chunk.message.tool_calls
if "response" in response_json:
response = response_json["response"]
tool_calls = None
stop_reason = None
if response_json["done"] not in ["true", True]:
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
if raw_tool_calls:
# return openai formatted tool calls
tool_calls = [{
"function": {
"name": call.function.name,
"arguments": call.function.arguments,
}
} for call in raw_tool_calls]
else:
response = response_json["message"]["content"]
raw_tool_calls = response_json["message"].get("tool_calls")
tool_calls = [ llm.ToolInput(tool_name=x["function"]["name"], tool_args=x["function"]["arguments"]) for x in raw_tool_calls] if raw_tool_calls else None
stop_reason = response_json.get("done_reason")
tool_calls = None
# _LOGGER.debug(f"{response=} {tool_calls=}")
return content, tool_calls
return response, tool_calls
@staticmethod
def _format_keep_alive(value: Any) -> Any:
as_text = str(value).strip()
return 0 if as_text in {"0", "0.0"} else f"{as_text}m"
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, entity_options: Dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
def _generate_stream(
self,
conversation: List[conversation.Content],
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: Dict[str, Any],
) -> AsyncGenerator[TextGenerationResult, None]:
model_name = entity_options.get(CONF_CHAT_MODEL, "")
context_length = entity_options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
max_tokens = entity_options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -145,58 +194,48 @@ class OllamaAPIClient(LocalLLMClient):
typical_p = entity_options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
timeout = entity_options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
keep_alive = entity_options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
legacy_tool_calling = entity_options.get(CONF_ENABLE_LEGACY_TOOL_CALLING, DEFAULT_ENABLE_LEGACY_TOOL_CALLING)
think_mode = entity_options.get(CONF_ENABLE_THINK_MODE, DEFAULT_ENABLE_THINK_MODE)
json_mode = entity_options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
request_params = {
"model": model_name,
"stream": True,
"keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
"options": {
"num_ctx": context_length,
"top_p": top_p,
"top_k": top_k,
"typical_p": typical_p,
"temperature": temperature,
"num_predict": max_tokens,
},
options = {
"num_ctx": context_length,
"top_p": top_p,
"top_k": top_k,
"typical_p": typical_p,
"temperature": temperature,
"num_predict": max_tokens,
"min_p": entity_options.get(CONF_MIN_P, DEFAULT_MIN_P),
}
if json_mode:
request_params["format"] = "json"
messages = get_oai_formatted_messages(conversation, tool_args_to_str=False)
tools = None
if llm_api and not legacy_tool_calling:
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
keep_alive_payload = self._format_keep_alive(keep_alive)
if llm_api:
request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
endpoint = "/api/chat"
request_params["messages"] = get_oai_formatted_messages(conversation, tool_args_to_str=False)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
session = async_get_clientsession(self.hass)
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
response = None
chunk = None
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
client = self._build_client(timeout=timeout)
try:
async with session.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=aiohttp.ClientTimeout(total=timeout),
headers=headers
) as response:
response.raise_for_status()
while True:
chunk = await response.content.readline()
if not chunk:
break
yield self._extract_response(json.loads(chunk))
except asyncio.TimeoutError as err:
raise HomeAssistantError("The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.") from err
except aiohttp.ClientError as err:
format_option = entity_options.get(CONF_RESPONSE_JSON_SCHEMA, "json" if json_mode else None)
stream = await client.chat(
model=model_name,
messages=messages,
tools=tools,
stream=True,
think=think_mode,
format=format_option,
options=options,
keep_alive=keep_alive_payload,
)
async for chunk in stream:
yield self._extract_response(chunk)
except httpx.TimeoutException as err:
raise HomeAssistantError(
"The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
) from err
except (ResponseError, ConnectionError) as err:
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())

View File

@@ -10,6 +10,7 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime
from homeassistant.components import conversation, ai_task
from homeassistant.data_entry_flow import (
AbortFlow,
)
@@ -46,6 +47,11 @@ from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
DEFAULT_AI_TASK_PROMPT,
CONF_AI_TASK_RETRIES,
DEFAULT_AI_TASK_RETRIES,
CONF_AI_TASK_EXTRACTION_METHOD,
DEFAULT_AI_TASK_EXTRACTION_METHOD,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
@@ -150,7 +156,7 @@ from .const import (
DEFAULT_OPTIONS,
option_overrides,
RECOMMENDED_CHAT_MODELS,
EMBEDDED_LLAMA_CPP_PYTHON_VERSION
EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
)
from . import HomeLLMAPI, LocalLLMConfigEntry, LocalLLMClient, BACKEND_TO_CLS
@@ -225,7 +231,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
"""Handle a config flow for Local LLM Conversation."""
VERSION = 3
MINOR_VERSION = 1
MINOR_VERSION = 2
install_wheel_task = None
install_wheel_error = None
@@ -399,8 +405,8 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration."""
return {
"conversation": LocalLLMSubentryFlowHandler,
# "ai_task_data": LocalLLMSubentryFlowHandler,
conversation.DOMAIN: LocalLLMSubentryFlowHandler,
ai_task.DOMAIN: LocalLLMSubentryFlowHandler,
}
@@ -583,40 +589,13 @@ def local_llama_config_option_schema(
backend_type: str,
subentry_type: str,
) -> dict:
default_prompt = build_prompt_template(language, DEFAULT_PROMPT)
result: dict = {
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, default_prompt)},
default=options.get(CONF_PROMPT, default_prompt),
): TemplateSelector(),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)},
default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_IN_CONTEXT_EXAMPLES_FILE,
description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
): str,
vol.Required(
CONF_NUM_IN_CONTEXT_EXAMPLES,
description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
vol.Required(
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
): TextSelector(TextSelectorConfig(multiple=True)),
vol.Required(
CONF_THINKING_PREFIX,
description={"suggested_value": options.get(CONF_THINKING_PREFIX)},
@@ -644,9 +623,114 @@ def local_llama_config_option_schema(
): bool,
}
if backend_type == BACKEND_TYPE_LLAMA_CPP:
if subentry_type == ai_task.DOMAIN:
result.update({
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT)},
default=options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT),
): TemplateSelector(),
vol.Required(
CONF_AI_TASK_EXTRACTION_METHOD,
description={"suggested_value": options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)},
default=options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD),
): SelectSelector(SelectSelectorConfig(
options=[
SelectOptionDict(value="none", label="None"),
SelectOptionDict(value="structure", label="Structured output"),
SelectOptionDict(value="tool", label="Tool call"),
],
mode=SelectSelectorMode.DROPDOWN,
)),
vol.Required(
CONF_AI_TASK_RETRIES,
description={"suggested_value": options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES)},
default=options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES),
): NumberSelector(NumberSelectorConfig(min=0, max=5, step=1, mode=NumberSelectorMode.BOX)),
})
elif subentry_type == conversation.DOMAIN:
default_prompt = build_prompt_template(language, DEFAULT_PROMPT)
apis: list[SelectOptionDict] = [
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
]
result.update({
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, default_prompt)},
default=options.get(CONF_PROMPT, default_prompt),
): TemplateSelector(),
vol.Required(
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_IN_CONTEXT_EXAMPLES_FILE,
description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
): str,
vol.Required(
CONF_NUM_IN_CONTEXT_EXAMPLES,
description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
vol.Required(
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
): TextSelector(TextSelectorConfig(multiple=True)),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default=None,
): SelectSelector(SelectSelectorConfig(options=apis, multiple=True)),
vol.Optional(
CONF_REFRESH_SYSTEM_PROMPT,
description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)},
default=options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT),
): BooleanSelector(BooleanSelectorConfig()),
vol.Optional(
CONF_REMEMBER_CONVERSATION,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)},
default=options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION),
): BooleanSelector(BooleanSelectorConfig()),
vol.Optional(
CONF_REMEMBER_NUM_INTERACTIONS,
description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)},
default=options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS),
): NumberSelector(NumberSelectorConfig(min=0, max=100, mode=NumberSelectorMode.BOX)),
vol.Optional(
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION)},
default=options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION),
): NumberSelector(NumberSelectorConfig(min=0, max=1440, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_MAX_TOOL_CALL_ITERATIONS,
description={"suggested_value": options.get(CONF_MAX_TOOL_CALL_ITERATIONS)},
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
): int,
})
if backend_type == BACKEND_TYPE_LLAMA_CPP:
if subentry_type == conversation.DOMAIN:
result.update({
vol.Required(
CONF_PROMPT_CACHING_ENABLED,
description={"suggested_value": options.get(CONF_PROMPT_CACHING_ENABLED)},
default=DEFAULT_PROMPT_CACHING_ENABLED,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_PROMPT_CACHING_INTERVAL,
description={"suggested_value": options.get(CONF_PROMPT_CACHING_INTERVAL)},
default=DEFAULT_PROMPT_CACHING_INTERVAL,
): NumberSelector(NumberSelectorConfig(min=1, max=60, step=1)),
})
result.update({
vol.Required(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS,
@@ -671,16 +755,6 @@ def local_llama_config_option_schema(
description={"suggested_value": options.get(CONF_TYPICAL_P)},
default=DEFAULT_TYPICAL_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_PROMPT_CACHING_ENABLED,
description={"suggested_value": options.get(CONF_PROMPT_CACHING_ENABLED)},
default=DEFAULT_PROMPT_CACHING_ENABLED,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_PROMPT_CACHING_INTERVAL,
description={"suggested_value": options.get(CONF_PROMPT_CACHING_INTERVAL)},
default=DEFAULT_PROMPT_CACHING_INTERVAL,
): NumberSelector(NumberSelectorConfig(min=1, max=60, step=1)),
# TODO: add rope_scaling_type
vol.Required(
CONF_CONTEXT_LENGTH,
@@ -879,60 +953,13 @@ def local_llama_config_option_schema(
): NumberSelector(NumberSelectorConfig(min=-1, max=1440, step=1, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)),
})
if subentry_type == "conversation":
apis: list[SelectOptionDict] = [
SelectOptionDict(
label="No control",
value="none",
)
]
apis.extend(
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
)
result.update({
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=apis)),
vol.Optional(
CONF_REFRESH_SYSTEM_PROMPT,
description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)},
default=options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT),
): BooleanSelector(BooleanSelectorConfig()),
vol.Optional(
CONF_REMEMBER_CONVERSATION,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)},
default=options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION),
): BooleanSelector(BooleanSelectorConfig()),
vol.Optional(
CONF_REMEMBER_NUM_INTERACTIONS,
description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)},
default=options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS),
): NumberSelector(NumberSelectorConfig(min=0, max=100, mode=NumberSelectorMode.BOX)),
vol.Optional(
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION)},
default=options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION),
): NumberSelector(NumberSelectorConfig(min=0, max=1440, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_MAX_TOOL_CALL_ITERATIONS,
description={"suggested_value": options.get(CONF_MAX_TOOL_CALL_ITERATIONS)},
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
): int,
})
elif subentry_type == "ai_task_data":
pass # no additional options for ai_task_data for now
# sort the options
global_order = [
# general
CONF_LLM_HASS_API,
CONF_PROMPT,
CONF_AI_TASK_EXTRACTION_METHOD,
CONF_AI_TASK_RETRIES,
CONF_CONTEXT_LENGTH,
CONF_MAX_TOKENS,
# sampling parameters
@@ -1122,8 +1149,16 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
description_placeholders = {}
entry = self._get_entry()
backend_type = entry.data[CONF_BACKEND_TYPE]
is_ai_task = self._subentry_type == ai_task.DOMAIN
if CONF_PROMPT not in self.model_config:
if is_ai_task:
if CONF_PROMPT not in self.model_config:
self.model_config[CONF_PROMPT] = DEFAULT_AI_TASK_PROMPT
if CONF_AI_TASK_RETRIES not in self.model_config:
self.model_config[CONF_AI_TASK_RETRIES] = DEFAULT_AI_TASK_RETRIES
if CONF_AI_TASK_EXTRACTION_METHOD not in self.model_config:
self.model_config[CONF_AI_TASK_EXTRACTION_METHOD] = DEFAULT_AI_TASK_EXTRACTION_METHOD
elif CONF_PROMPT not in self.model_config:
# determine selected language from model config or parent options
selected_language = self.model_config.get(
CONF_SELECTED_LANGUAGE, entry.options.get(CONF_SELECTED_LANGUAGE, "en")
@@ -1156,20 +1191,21 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
)
if user_input:
if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
errors["base"] = "sys_refresh_caching_enabled"
if not is_ai_task:
if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
errors["base"] = "sys_refresh_caching_enabled"
if user_input.get(CONF_USE_GBNF_GRAMMAR):
filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_gbnf_file"
description_placeholders["filename"] = filename
if user_input.get(CONF_USE_GBNF_GRAMMAR):
filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_gbnf_file"
description_placeholders["filename"] = filename
if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_icl_file"
description_placeholders["filename"] = filename
if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_icl_file"
description_placeholders["filename"] = filename
# --- Normalize numeric fields to ints to avoid slice/type errors later ---
for key in (
@@ -1178,6 +1214,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
CONF_CONTEXT_LENGTH,
CONF_MAX_TOKENS,
CONF_REQUEST_TIMEOUT,
CONF_AI_TASK_RETRIES,
):
if key in user_input:
user_input[key] = _coerce_int(user_input[key], user_input.get(key) or 0)
@@ -1187,10 +1224,6 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
# validate input
schema(user_input)
self.model_config.update(user_input)
# clear LLM API if 'none' selected
if self.model_config.get(CONF_LLM_HASS_API) == "none":
self.model_config.pop(CONF_LLM_HASS_API, None)
return await self.async_step_finish()
except Exception:

View File

@@ -8,6 +8,11 @@ SERVICE_TOOL_NAME = "HassCallService"
SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"]
SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"]
CONF_PROMPT = "prompt"
DEFAULT_AI_TASK_PROMPT = "You are a task-specific assistant. Follow the task instructions and return the requested data."
CONF_AI_TASK_RETRIES = "ai_task_retries"
DEFAULT_AI_TASK_RETRIES = 1
CONF_AI_TASK_EXTRACTION_METHOD = "ai_task_extraction_method"
DEFAULT_AI_TASK_EXTRACTION_METHOD = "structure"
PERSONA_PROMPTS = {
"en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.",
"de": "Du bist \u201eAl\u201c, ein hilfreicher KI-Assistent, der die Ger\u00e4te in einem Haus steuert. F\u00fchren Sie die folgende Aufgabe gem\u00e4\u00df den Anweisungen durch oder beantworten Sie die folgende Frage nur mit den bereitgestellten Informationen.",
@@ -104,6 +109,8 @@ CONF_TEMPERATURE = "temperature"
DEFAULT_TEMPERATURE = 0.1
CONF_REQUEST_TIMEOUT = "request_timeout"
DEFAULT_REQUEST_TIMEOUT = 90
CONF_ENABLE_THINK_MODE = "enable_think_mode"
DEFAULT_ENABLE_THINK_MODE = False
CONF_BACKEND_TYPE = "model_backend"
BACKEND_TYPE_LLAMA_HF_OLD = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING_OLD = "llama_cpp_existing"
@@ -185,7 +192,8 @@ DEFAULT_GENERIC_OPENAI_PATH = "v1"
CONF_GENERIC_OPENAI_VALIDATE_MODEL = "openai_validate_model"
DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL = True
CONF_CONTEXT_LENGTH = "context_length"
DEFAULT_CONTEXT_LENGTH = 2048
DEFAULT_CONTEXT_LENGTH = 8192
CONF_RESPONSE_JSON_SCHEMA = "response_json_schema"
CONF_LLAMACPP_BATCH_SIZE = "batch_size"
DEFAULT_LLAMACPP_BATCH_SIZE = 512
CONF_LLAMACPP_THREAD_COUNT = "n_threads"

View File

@@ -17,6 +17,7 @@ from custom_components.llama_conversation.utils import MalformedToolCallExceptio
from .entity import LocalLLMEntity, LocalLLMClient, LocalLLMConfigEntry
from .const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
@@ -39,6 +40,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry, asy
if subentry.subentry_type != conversation.DOMAIN:
continue
if CONF_CHAT_MODEL not in subentry.data:
_LOGGER.warning("Conversation subentry %s missing required config key %s, You must delete the model and re-create it.", subentry.subentry_id, CONF_CHAT_MODEL)
continue
# create one agent entity per conversation subentry
agent_entity = LocalLLMAgent(hass, entry, subentry, entry.runtime_data)

View File

@@ -5,15 +5,16 @@ import csv
import logging
import os
import random
from typing import Literal, Any, List, Dict, Optional, Tuple, AsyncIterator, Generator, AsyncGenerator
import re
from typing import Literal, Any, List, Dict, Optional, Sequence, Tuple, AsyncIterator, Generator, AsyncGenerator
from dataclasses import dataclass
from homeassistant.components import conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import MATCH_ALL, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import template, entity_registry as er, llm, \
area_registry as ar, device_registry as dr, entity
from homeassistant.util import color
@@ -41,8 +42,6 @@ from .const import (
DEFAULT_TOOL_CALL_PREFIX,
DEFAULT_TOOL_CALL_SUFFIX,
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME,
)
_LOGGER = logging.getLogger(__name__)
@@ -184,28 +183,25 @@ class LocalLLMClient:
_LOGGER.debug("Received chunk: %s", input_chunk)
tool_calls = input_chunk.tool_calls
# fix tool calls for the service tool
if tool_calls and chat_log.llm_api and chat_log.llm_api.api.id == HOME_LLM_API_ID:
tool_calls = [
llm.ToolInput(
tool_name=SERVICE_TOOL_NAME,
tool_args={**tc.tool_args, "service": tc.tool_name}
) for tc in tool_calls
]
if tool_calls and not chat_log.llm_api:
raise HomeAssistantError("Model attempted to call a tool but no LLM API was provided")
yield conversation.AssistantContentDeltaDict(
content=input_chunk.response,
tool_calls=tool_calls
tool_calls=tool_calls
)
return chat_log.async_add_delta_content_stream(agent_id, stream=async_iterator())
async def _async_parse_completion(
self, llm_api: llm.APIInstance | None,
async def _async_stream_parse_completion(
self,
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: Dict[str, Any],
next_token: Optional[Generator[Tuple[Optional[str], Optional[List]]]] = None,
anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[List]]]] = None,
next_token: Optional[Generator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
) -> AsyncGenerator[TextGenerationResult, None]:
"""Parse streaming completion with tool calls from the backend. Accepts either a sync or async token generator."""
think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
@@ -236,7 +232,7 @@ class LocalLLMClient:
cur_match_length = 0
async for chunk in token_generator:
# _LOGGER.debug(f"Handling chunk: {chunk} {in_thinking=} {in_tool_call=} {last_5_tokens=}")
tool_calls: Optional[List[str | llm.ToolInput | dict]]
tool_calls: Optional[List[str | dict]]
content, tool_calls = chunk
if not tool_calls:
@@ -289,31 +285,73 @@ class LocalLLMClient:
_LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
else:
for raw_tool_call in tool_calls:
if isinstance(raw_tool_call, llm.ToolInput):
parsed_tool_calls.append(raw_tool_call)
if isinstance(raw_tool_call, str):
tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
else:
if isinstance(raw_tool_call, str):
tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, agent_id)
else:
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, agent_id)
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
if tool_call:
_LOGGER.debug("Tool call parsed: %s", tool_call)
parsed_tool_calls.append(tool_call)
if to_say:
result.response = to_say
if tool_call:
_LOGGER.debug("Tool call parsed: %s", tool_call)
parsed_tool_calls.append(tool_call)
if to_say:
result.response = to_say
if len(parsed_tool_calls) > 0:
result.tool_calls = parsed_tool_calls
if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
yield result
async def _async_parse_completion(
self,
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: Dict[str, Any],
completion: str | dict) -> TextGenerationResult:
"""Parse completion with tool calls from the backend."""
think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
think_regex = re.compile(re.escape(think_prefix) + "(.*?)" + re.escape(think_suffix), re.DOTALL)
tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
tool_suffix = entity_options.get(CONF_TOOL_CALL_SUFFIX, DEFAULT_TOOL_CALL_SUFFIX)
tool_regex = re.compile(re.escape(tool_prefix) + "(.*?)" + re.escape(tool_suffix), re.DOTALL)
if isinstance(completion, dict):
completion = str(completion.get("response", ""))
# Remove thinking blocks, and extract tool calls
tool_calls = tool_regex.findall(completion)
completion = think_regex.sub("", completion)
completion = tool_regex.sub("", completion)
to_say = ""
parsed_tool_calls: list[llm.ToolInput] = []
if len(tool_calls) and not llm_api:
_LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
else:
for raw_tool_call in tool_calls:
if isinstance(raw_tool_call, llm.ToolInput):
parsed_tool_calls.append(raw_tool_call)
else:
if isinstance(raw_tool_call, str):
tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
else:
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
if tool_call:
_LOGGER.debug("Tool call parsed: %s", tool_call)
parsed_tool_calls.append(tool_call)
return TextGenerationResult(
response=completion + (to_say or ""),
tool_calls=parsed_tool_calls,
)
def _async_get_all_exposed_domains(self) -> list[str]:
"""Gather all exposed domains"""
domains = set()
for state in self.hass.states.async_all():
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id):
domains.add(state.domain)
return list(domains)
@@ -326,7 +364,7 @@ class LocalLLMClient:
area_registry = ar.async_get(self.hass)
for state in self.hass.states.async_all():
if not async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
if not async_should_expose(self.hass, conversation.DOMAIN, state.entity_id):
continue
entity = entity_registry.async_get(state.entity_id)

View File

@@ -1,16 +1,17 @@
{
"domain": "llama_conversation",
"name": "Local LLMs",
"version": "0.4.4",
"version": "0.4.5",
"codeowners": ["@acon96"],
"config_flow": true,
"dependencies": ["conversation"],
"dependencies": ["conversation", "ai_task"],
"after_dependencies": ["assist_pipeline", "intent"],
"documentation": "https://github.com/acon96/home-llm",
"integration_type": "service",
"iot_class": "local_polling",
"requirements": [
"huggingface-hub>=0.23.0",
"webcolors>=24.8.0"
"webcolors>=24.8.0",
"ollama>=0.5.1"
]
}

View File

@@ -70,7 +70,7 @@
"model_parameters": {
"data": {
"max_new_tokens": "Maximum tokens to return in response",
"llm_hass_api": "Selected LLM API",
"llm_hass_api": "Selected LLM API(s)",
"prompt": "System Prompt",
"temperature": "Temperature",
"top_k": "Top K",
@@ -109,7 +109,7 @@
"max_tool_call_iterations": "Maximum Tool Call Attempts"
},
"data_description": {
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-3)'",
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM (v1-3) model then select 'Home Assistant Services'",
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
@@ -124,7 +124,7 @@
"reconfigure": {
"data": {
"max_new_tokens": "Maximum tokens to return in response",
"llm_hass_api": "Selected LLM API",
"llm_hass_api": "Selected LLM API(s)",
"prompt": "System Prompt",
"temperature": "Temperature",
"top_k": "Top K",
@@ -163,7 +163,7 @@
"max_tool_call_iterations": "Maximum Tool Call Attempts"
},
"data_description": {
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-3)'",
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM (v1-3) model then select 'Home Assistant Services'",
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
@@ -177,7 +177,7 @@
}
}
},
"ai_task_data": {
"ai_task": {
"initiate_flow": {
"user": "Add AI Task Handler",
"reconfigure": "Reconfigure AI Task Handler"
@@ -246,7 +246,9 @@
"tool_call_prefix": "Tool Call Prefix",
"tool_call_suffix": "Tool Call Suffix",
"enable_legacy_tool_calling": "Enable Legacy Tool Calling",
"max_tool_call_iterations": "Maximum Tool Call Attempts"
"max_tool_call_iterations": "Maximum Tool Call Attempts",
"ai_task_extraction_method": "Structured Data Extraction Method",
"ai_task_retries": "Retry attempts for structured data extraction"
},
"data_description": {
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
@@ -255,7 +257,8 @@
"gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.",
"prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below.",
"enable_legacy_tool_calling": "Prefer to process tool calls locally rather than relying on the backend to handle the tool calling format. Can be more reliable, however it requires properly setting the tool call prefix and suffix.",
"max_tool_call_iterations": "Set to 0 to generate the response and tool call in one attempt, without looping (use this for Home models v1-v3)."
"max_tool_call_iterations": "Set to 0 to generate the response and tool call in one attempt, without looping (use this for Home models v1-v3).",
"ai_task_extraction_method": "Select the method used to extract structured data from the model's response. 'Structured Output' tells the backend to force the model to produce output following the provided JSON Schema; 'Tool Calling' provides a tool to the model that should be called with the appropriate arguments that match the desired output structure."
},
"description": "Please configure the model according to how it should be prompted. There are many different options and selecting the correct ones for your model is essential to getting optimal performance. See [here](https://github.com/acon96/home-llm/blob/develop/docs/Backend%20Configuration.md) for more information about the options on this page.\n\n**Some defaults may have been chosen for you based on the name of the selected model name or filename.** If you renamed a file or are using a fine-tuning of a supported model, then the defaults may not have been detected.",
"title": "Configure the selected model"
@@ -263,7 +266,7 @@
"reconfigure": {
"data": {
"max_new_tokens": "Maximum tokens to return in response",
"llm_hass_api": "Selected LLM API",
"llm_hass_api": "Selected LLM API(s)",
"prompt": "System Prompt",
"temperature": "Temperature",
"top_k": "Top K",

View File

@@ -24,7 +24,7 @@ from homeassistant.requirements import pip_kwargs
from homeassistant.util import color
from homeassistant.util.package import install_package, is_installed
from voluptuous_openapi import convert
from voluptuous_openapi import convert as convert_to_openapi
from .const import (
DOMAIN,
@@ -32,7 +32,7 @@ from .const import (
ALLOWED_SERVICE_CALL_ARGUMENTS,
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME,
)
from typing import TYPE_CHECKING
@@ -275,26 +275,30 @@ def install_llama_cpp_python(config_dir: str, force_reinstall: bool = False, spe
def format_url(*, hostname: str, port: str, ssl: bool, path: str):
return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
if llm_api.api.id == HOME_LLM_API_ID:
result: List[ChatCompletionTool] = [ {
"type": "function",
"function": {
"name": tool["name"],
"description": f"Call the Home Assistant service '{tool['name']}'",
"parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
}
} for tool in get_home_llm_tools(llm_api, domains) ]
else:
result: List[ChatCompletionTool] = [ {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
}
} for tool in llm_api.tools ]
def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
result: List[ChatCompletionTool] = []
for tool in llm_api.tools:
# when multiple Home Assistant LLM APIs are combined, Home Assistant prefixes each tool name to differentiate them, so compare against the suffix here
if tool.name.endswith(SERVICE_TOOL_NAME):
result.extend([{
"type": "function",
"function": {
"name": tool["name"],
"description": f"Call the Home Assistant service '{tool['name']}'",
"parameters": convert_to_openapi(tool["arguments"], custom_serializer=llm_api.custom_serializer)
}
} for tool in get_home_llm_tools(llm_api, domains) ])
else:
result.append({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": convert_to_openapi(tool.parameters, custom_serializer=llm_api.custom_serializer)
}
})
return result
@@ -396,41 +400,44 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic
return tools
def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
if isinstance(raw_block, dict):
parsed_tool_call = raw_block
else:
parsed_tool_call: dict = json.loads(raw_block)
if llm_api.api.id == HOME_LLM_API_ID:
schema_to_validate = vol.Schema({
vol.Required('service'): str,
vol.Required('target_device'): str,
vol.Optional('rgb_color'): str,
vol.Optional('brightness'): vol.Coerce(float),
vol.Optional('temperature'): vol.Coerce(float),
vol.Optional('humidity'): vol.Coerce(float),
vol.Optional('fan_mode'): str,
vol.Optional('hvac_mode'): str,
vol.Optional('preset_mode'): str,
vol.Optional('duration'): str,
vol.Optional('item'): str,
})
else:
schema_to_validate = vol.Schema({
# try to validate either format
is_services_tool_call = False
try:
base_schema_to_validate = vol.Schema({
vol.Required("name"): str,
vol.Required("arguments"): vol.Union(str, dict),
})
try:
schema_to_validate(parsed_tool_call)
base_schema_to_validate(parsed_tool_call)
except vol.Error as ex:
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
try:
home_llm_schema_to_validate = vol.Schema({
vol.Required('service'): str,
vol.Required('target_device'): str,
vol.Optional('rgb_color'): str,
vol.Optional('brightness'): vol.Coerce(float),
vol.Optional('temperature'): vol.Coerce(float),
vol.Optional('humidity'): vol.Coerce(float),
vol.Optional('fan_mode'): str,
vol.Optional('hvac_mode'): str,
vol.Optional('preset_mode'): str,
vol.Optional('duration'): str,
vol.Optional('item'): str,
})
home_llm_schema_to_validate(parsed_tool_call)
is_services_tool_call = True
except vol.Error as ex:
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
# try to fix certain arguments
args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
tool_name = parsed_tool_call.get("name", parsed_tool_call.get("service", ""))
args_dict = parsed_tool_call if is_services_tool_call else parsed_tool_call["arguments"]
tool_name = SERVICE_TOOL_NAME if is_services_tool_call else parsed_tool_call["name"]
if isinstance(args_dict, str):
if not args_dict.strip():

View File

@@ -6,4 +6,5 @@ home-assistant-intents
# testing requirements
pytest
pytest-asyncio
pytest-homeassistant-custom-component==0.13.260
# NOTE this must match the version of Home Assistant used for testing
pytest-homeassistant-custom-component==0.13.272

View File

@@ -1,2 +1,3 @@
huggingface-hub>=0.23.0
webcolors>=24.8.0
ollama>=0.5.1

View File

@@ -30,8 +30,8 @@ Start by installing system dependencies:
Then create a Python virtual environment and install all necessary libraries:
```
python3 -m venv .generate_data
source ./.generate_data/bin/activate
pip3 install pandas==2.2.2 datasets==2.20.0 webcolors==1.13 babel==2.15.0
source .generate_data/bin/activate
pip3 install -r requirements.txt
```
## Generating the dataset from piles

View File

@@ -1,5 +1,5 @@
datasets>=3.2.0
webcolors>=1.13
webcolors>=24.8.0
pandas>=2.2.3
deep-translator>=1.11.4
langcodes>=3.5.0

52
docs/AI Tasks.md Normal file
View File

@@ -0,0 +1,52 @@
# Using AI Tasks
The AI Tasks feature allows you to define structured tasks that your local LLM can perform. These tasks can be integrated into Home Assistant automations and scripts, enabling you to generate dynamic content based on specific prompts and instructions.
## Setting up an AI Task Handler
Setting up a task handler is similar to setting up a conversation agent. You can choose to run the model directly within Home Assistant using `llama-cpp-python`, or you can use an external backend like Ollama. See the [Setup Guide](./Setup.md) for detailed instructions on configuring your AI Task handler.
The specific configuration options for AI Tasks are:
| Option Name | Description |
|-----------------------------------|-------------------------------------------------------------------------------------------------------------------------------------|
| Structured Data Extraction Method | Choose how the AI Task should extract structured data from the model's output. Options include `structured_output` and `tool`. |
| Data Extraction Retry Count | The number of times to retry data extraction if the initial attempt fails. Useful when models can produce incorrect tool responses. |
If no structured data extraction method is specified, then the task entity will always return raw text.
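The chosen extraction method determines what ends up in the `data` field of the action's `response_variable`. As a rough sketch (the `task_output` variable and `some_field` key below are placeholders, not names provided by the integration; see the full script example in the next section):
```yaml
# No extraction method configured: `data` holds the raw response text.
- action: notify.persistent_notification
  data:
    message: "{{ task_output.data }}"

# 'Structured Output' or 'Tool Calling' configured: `data` is a dictionary
# matching the `structure` passed to ai_task.generate_data.
- action: notify.persistent_notification
  data:
    message: "{{ task_output.data.some_field }}"
```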
## Using an AI Task in a Script or Automation
To use an AI Task in a Home Assistant script or automation, call the `ai_task.generate_data` action. This action lets you specify the task name, instructions, and the structure of the expected output. Below is an example of a script that generates a joke about one of the smart devices in your home.
**Device Joke Script:**
```yaml
sequence:
- action: ai_task.generate_data
data:
task_name: Device Joke Generation
instructions: |
Write a funny joke about one of the smart devices in my home.
Here are all of the smart devices I have:
{% for device in states | rejectattr('domain', 'in', ['update', 'event']) -%}
- {{ device.name }} ({{device.domain}})
{% endfor %}
# You MUST set this to your own LLM entity ID if you do not set a default one in HA Settings
# entity_id: ai_task.unsloth_qwen3_0_6b_gguf_unsloth_qwen3_0_6b_gguf
structure:
joke_setup:
description: The beginning of a joke about a smart device in the home
required: true
selector:
text: null
joke_punchline:
description: The punchline of the same joke about the smart device
required: true
selector:
text: null
response_variable: joke_output
- action: notify.persistent_notification
data:
message: |-
{{ joke_output.data.joke_setup }}
...
{{ joke_output.data.joke_punchline }}
alias: Device Joke
description: "Generates a funny joke about one of the smart devices in the home."
```

View File

@@ -4,7 +4,7 @@ datasets>=3.2.0
peft>=0.14.0
bitsandbytes>=0.45.2
trl>=0.14.0
webcolors>=1.13
webcolors>=24.8.0
pandas>=2.2.3
flash-attn
sentencepiece>=0.2.0

View File

@@ -1,763 +0,0 @@
import json
import logging
import pytest
import jinja2
from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY
from custom_components.llama_conversation.backends.llamacpp import LlamaCppAgent
from custom_components.llama_conversation.backends.ollama import OllamaAPIAgent
from custom_components.llama_conversation.backends.tailored_openai import TextGenerationWebuiAgent
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_PROMPT_TEMPLATE,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
CONF_BATCH_SIZE,
CONF_THREAD_COUNT,
CONF_BATCH_THREAD_COUNT,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT_BASE,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_PROMPT_CACHING_ENABLED,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_BATCH_SIZE,
DEFAULT_THREAD_COUNT,
DEFAULT_BATCH_THREAD_COUNT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
DOMAIN,
PROMPT_TEMPLATE_DESCRIPTIONS,
DEFAULT_OPTIONS,
)
from homeassistant.components.conversation import ConversationInput
from homeassistant.const import (
CONF_HOST,
CONF_PORT,
CONF_SSL,
CONF_LLM_HASS_API
)
from homeassistant.helpers.llm import LLM_API_ASSIST, APIInstance
_LOGGER = logging.getLogger(__name__)
class WarnDict(dict):
def get(self, _key, _default=None):
if _key in self:
return self[_key]
_LOGGER.warning(f"attempting to get unset dictionary key {_key}")
return _default
class MockConfigEntry:
def __init__(self, entry_id='test_entry_id', data={}, options={}):
self.entry_id = entry_id
self.data = WarnDict(data)
# Use a mutable dict for options in tests
self.options = WarnDict(dict(options))
@pytest.fixture
def config_entry():
yield MockConfigEntry(
data={
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
CONF_BACKEND_TYPE: DEFAULT_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE: "/config/models/some-model.q4_k_m.gguf",
CONF_HOST: "localhost",
CONF_PORT: "5000",
CONF_SSL: False,
CONF_OPENAI_API_KEY: "OpenAI-API-Key",
CONF_TEXT_GEN_WEBUI_ADMIN_KEY: "Text-Gen-Webui-Admin-Key"
},
options={
**DEFAULT_OPTIONS,
CONF_LLM_HASS_API: LLM_API_ASSIST,
CONF_PROMPT: DEFAULT_PROMPT_BASE,
CONF_SERVICE_CALL_REGEX: r"({[\S \t]*})"
}
)
@pytest.fixture
def local_llama_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(LlamaCppAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(LlamaCppAgent, '_load_grammar') as load_grammar_mock, \
patch.object(LlamaCppAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.llamacpp.importlib.import_module') as import_module_mock, \
patch('custom_components.llama_conversation.utils.importlib.import_module') as import_module_mock_2, \
patch('custom_components.llama_conversation.utils.install_llama_cpp_python') as install_llama_cpp_python_mock:
entry_mock.return_value = config_entry
llama_instance_mock = MagicMock()
llama_class_mock = MagicMock()
llama_class_mock.return_value = llama_instance_mock
import_module_mock.return_value = MagicMock(Llama=llama_class_mock)
import_module_mock_2.return_value = MagicMock(Llama=llama_class_mock)
install_llama_cpp_python_mock.return_value = True
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
# template_mock.side_affect = lambda template, _: jinja2.Template(template)
generate_mock = llama_instance_mock.generate
generate_mock.return_value = list(range(20))
detokenize_mock = llama_instance_mock.detokenize
detokenize_mock.return_value = ("I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
})).encode()
call_tool_mock.return_value = {"result": "success"}
agent_obj = LlamaCppAgent(
hass,
config_entry
)
all_mocks = {
"llama_class": llama_class_mock,
"tokenize": llama_instance_mock.tokenize,
"generate": generate_mock,
}
yield agent_obj, all_mocks
# TODO: test base llama agent (ICL loading other languages)
async def test_local_llama_agent(local_llama_agent_fixture):
local_llama_agent: LlamaCppAgent
all_mocks: dict[str, MagicMock]
local_llama_agent, all_mocks = local_llama_agent_fixture
# invoke the conversation agent
conversation_id = "test-conversation"
result = await local_llama_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["llama_class"].assert_called_once_with(
model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
)
all_mocks["tokenize"].assert_called_once()
all_mocks["generate"].assert_called_once_with(
ANY,
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
min_p=local_llama_agent.entry.options[CONF_MIN_P],
grammar=ANY,
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# change options then apply them
local_llama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
local_llama_agent.entry.options[CONF_BATCH_SIZE] = 1024
local_llama_agent.entry.options[CONF_THREAD_COUNT] = 24
local_llama_agent.entry.options[CONF_BATCH_THREAD_COUNT] = 24
local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
local_llama_agent.entry.options[CONF_ENABLE_FLASH_ATTENTION] = True
local_llama_agent.entry.options[CONF_TOP_K] = 20
local_llama_agent.entry.options[CONF_TOP_P] = 0.9
local_llama_agent.entry.options[CONF_MIN_P] = 0.2
local_llama_agent.entry.options[CONF_TYPICAL_P] = 0.95
local_llama_agent._update_options()
all_mocks["llama_class"].assert_called_once_with(
model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
)
# do another turn of the same conversation
result = await local_llama_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["tokenize"].assert_called_once()
all_mocks["generate"].assert_called_once_with(
ANY,
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
min_p=local_llama_agent.entry.options[CONF_MIN_P],
grammar=ANY,
)
@pytest.fixture
def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(OllamaAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(OllamaAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(OllamaAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.ollama.async_get_clientsession') as get_clientsession:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
response_mock = MagicMock()
response_mock.json.return_value = { "models": [ {"name": config_entry.data[CONF_CHAT_MODEL] }] }
get_clientsession.get.return_value = response_mock
call_tool_mock.return_value = {"result": "success"}
agent_obj = OllamaAPIAgent(
hass,
config_entry
)
all_mocks = {
"requests_get": get_clientsession.get,
"requests_post": get_clientsession.post
}
yield agent_obj, all_mocks
async def test_ollama_agent(ollama_agent_fixture):
ollama_agent: OllamaAPIAgent
all_mocks: dict[str, MagicMock]
ollama_agent, all_mocks = ollama_agent_fixture
all_mocks["requests_get"].assert_called_once_with(
"http://localhost:5000/api/tags",
headers={ "Authorization": "Bearer OpenAI-API-Key" }
)
response_mock = MagicMock()
response_mock.json.return_value = {
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
"created_at": "2023-11-09T21:07:55.186497Z",
"response": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"done": True,
"context": [1, 2, 3],
"total_duration": 4648158584,
"load_duration": 4071084,
"prompt_eval_count": 36,
"prompt_eval_duration": 439038000,
"eval_count": 180,
"eval_duration": 4196918000
}
all_mocks["requests_post"].return_value = response_mock
# invoke the conversation agent
conversation_id = "test-conversation"
result = await ollama_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/api/generate",
headers={ "Authorization": "Bearer OpenAI-API-Key" },
json={
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
"stream": False,
"keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
"options": {
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
"top_p": ollama_agent.entry.options[CONF_TOP_P],
"top_k": ollama_agent.entry.options[CONF_TOP_K],
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
},
"prompt": ANY,
"raw": True
},
timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# change options
ollama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
ollama_agent.entry.options[CONF_MAX_TOKENS] = 10
ollama_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN] = 99
ollama_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
ollama_agent.entry.options[CONF_OLLAMA_JSON_MODE] = True
ollama_agent.entry.options[CONF_TEMPERATURE] = 2.0
ollama_agent.entry.options[CONF_TOP_K] = 20
ollama_agent.entry.options[CONF_TOP_P] = 0.9
ollama_agent.entry.options[CONF_TYPICAL_P] = 0.5
# do another turn of the same conversation
result = await ollama_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/api/chat",
headers={ "Authorization": "Bearer OpenAI-API-Key" },
json={
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
"stream": False,
"format": "json",
"keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
"options": {
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
"top_p": ollama_agent.entry.options[CONF_TOP_P],
"top_k": ollama_agent.entry.options[CONF_TOP_K],
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
},
"messages": ANY
},
timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
@pytest.fixture
def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(TextGenerationWebuiAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(TextGenerationWebuiAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(TextGenerationWebuiAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.tailored_openai.async_get_clientsession') as get_clientsession_mock:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
response_mock = MagicMock()
response_mock.json.return_value = { "model_name": config_entry.data[CONF_CHAT_MODEL] }
get_clientsession_mock.get.return_value = response_mock
call_tool_mock.return_value = {"result": "success"}
agent_obj = TextGenerationWebuiAgent(
hass,
config_entry
)
all_mocks = {
"requests_get": get_clientsession_mock.get,
"requests_post": get_clientsession_mock.post
}
yield agent_obj, all_mocks
async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
text_generation_webui_agent: TextGenerationWebuiAgent
all_mocks: dict[str, MagicMock]
text_generation_webui_agent, all_mocks = text_generation_webui_agent_fixture
all_mocks["requests_get"].assert_called_once_with(
"http://localhost:5000/v1/internal/model/info",
headers={ "Authorization": "Bearer Text-Gen-Webui-Admin-Key" }
)
response_mock = MagicMock()
response_mock.json.return_value = {
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "gpt-3.5-turbo-instruct",
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"text": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"index": 0,
"logprobs": None,
"finish_reason": "length"
}],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
}
}
all_mocks["requests_post"].return_value = response_mock
# invoke the conversation agent
conversation_id = "test-conversation"
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"prompt": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = "Some Preset"
# do another turn of the same conversation and use a preset
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"preset": "Some Preset",
"prompt": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# change options
text_generation_webui_agent.entry.options[CONF_MAX_TOKENS] = 10
text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
text_generation_webui_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
text_generation_webui_agent.entry.options[CONF_TEMPERATURE] = 2.0
text_generation_webui_agent.entry.options[CONF_TOP_P] = 0.9
text_generation_webui_agent.entry.options[CONF_MIN_P] = 0.2
text_generation_webui_agent.entry.options[CONF_TYPICAL_P] = 0.95
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = ""
response_mock.json.return_value = {
"id": "chatcmpl-123",
# text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
"object": "chat.completions",
"created": 1677652288,
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# do another turn of the same conversation but the chat endpoint
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/chat/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
"messages": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
@pytest.fixture
def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(GenericOpenAIAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(GenericOpenAIAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(GenericOpenAIAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.generic_openai.async_get_clientsession') as get_clientsession_mock:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
call_tool_mock.return_value = {"result": "success"}
agent_obj = GenericOpenAIAPIAgent(
hass,
config_entry
)
all_mocks = {
"requests_get": get_clientsession_mock.get,
"requests_post": get_clientsession_mock.post
}
yield agent_obj, all_mocks
async def test_generic_openai_agent(generic_openai_agent_fixture):
generic_openai_agent: TextGenerationWebuiAgent
all_mocks: dict[str, MagicMock]
generic_openai_agent, all_mocks = generic_openai_agent_fixture
response_mock = MagicMock()
response_mock.json.return_value = {
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "gpt-3.5-turbo-instruct",
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"text": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"index": 0,
"logprobs": None,
"finish_reason": "length"
}],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
}
}
all_mocks["requests_post"].return_value = response_mock
# invoke the conversation agent
conversation_id = "test-conversation"
result = await generic_openai_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/completions",
json={
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
"top_p": generic_openai_agent.entry.options[CONF_TOP_P],
"temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
"max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
"prompt": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# change options
generic_openai_agent.entry.options[CONF_MAX_TOKENS] = 10
generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
generic_openai_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
generic_openai_agent.entry.options[CONF_TEMPERATURE] = 2.0
generic_openai_agent.entry.options[CONF_TOP_P] = 0.9
response_mock.json.return_value = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}
# do another turn of the same conversation but the chat endpoint
result = await generic_openai_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/chat/completions",
json={
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
"top_p": generic_openai_agent.entry.options[CONF_TOP_P],
"temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
"max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
"messages": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
)

View File

@@ -0,0 +1,144 @@
"""Tests for AI Task extraction behavior."""
from typing import Any, cast
import pytest
import voluptuous as vol
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from custom_components.llama_conversation.ai_task import (
LocalLLMTaskEntity,
ResultExtractionMethod,
)
from custom_components.llama_conversation.const import (
CONF_AI_TASK_EXTRACTION_METHOD,
)
from custom_components.llama_conversation.entity import TextGenerationResult
class DummyGenTask:
def __init__(self, *, name="task", instructions="do", attachments=None, structure=None):
self.name = name
self.instructions = instructions
self.attachments = attachments or []
self.structure = structure
class DummyChatLog:
def __init__(self, content=None):
self.content = content or []
self.conversation_id = "conv-id"
self.llm_api = None
class DummyClient:
def __init__(self, result: TextGenerationResult):
self._result = result
def _generate_system_prompt(self, prompt_template, llm_api, entity_options):
return prompt_template
def _supports_vision(self, _options): # pragma: no cover - not needed for tests
return False
async def _generate(self, _messages, _llm_api, _entity_id, _options):
return self._result
class DummyTaskEntity(LocalLLMTaskEntity):
def __init__(self, hass, client, options):
# Bypass parent init to avoid ConfigEntry/Subentry plumbing.
self.hass = hass
self.client = client
self._runtime_options = options
self.entity_id = "ai_task.test"
self.entry_id = "entry"
self.subentry_id = "subentry"
self._attr_supported_features = 0
@property
def runtime_options(self):
return self._runtime_options
@pytest.mark.asyncio
async def test_no_extraction_returns_raw_text(hass):
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="raw text")),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.NONE},
)
chat_log = DummyChatLog()
task = DummyGenTask()
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
assert result.data == "raw text"
assert chat_log.llm_api is None
@pytest.mark.asyncio
async def test_structured_output_success(hass):
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response='{"foo": 1}')),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure=vol.Schema({"foo": int}))
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
assert result.data == {"foo": 1}
@pytest.mark.asyncio
async def test_structured_output_invalid_json_raises(hass):
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="not-json")),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure=vol.Schema({"foo": int}))
with pytest.raises(HomeAssistantError):
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
@pytest.mark.asyncio
async def test_tool_extraction_success(hass):
tool_call = llm.ToolInput("submit_response", {"value": 42})
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="", tool_calls=[tool_call])),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure=vol.Schema({"value": int}))
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
assert result.data == {"value": 42}
assert chat_log.llm_api is not None
@pytest.mark.asyncio
async def test_tool_extraction_missing_tool_args_raises(hass):
class DummyToolCall:
def __init__(self, tool_args=None):
self.tool_args = tool_args
tool_call = DummyToolCall(tool_args=None)
entity = DummyTaskEntity(
hass,
DummyClient(TextGenerationResult(response="", tool_calls=cast(Any, [tool_call]))),
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
)
chat_log = DummyChatLog()
task = DummyGenTask(structure=vol.Schema({"value": int}))
with pytest.raises(HomeAssistantError):
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))

View File

@@ -0,0 +1,105 @@
"""Lightweight smoke tests for backend helpers.
These avoid backend calls and only cover helper utilities to keep the suite green
while the integration evolves. No integration code is modified.
"""
import pytest
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from custom_components.llama_conversation.backends.llamacpp import snapshot_settings
from custom_components.llama_conversation.backends.ollama import OllamaAPIClient, _normalize_path
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIClient
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_CONTEXT_LENGTH,
CONF_LLAMACPP_BATCH_SIZE,
CONF_LLAMACPP_BATCH_THREAD_COUNT,
CONF_LLAMACPP_THREAD_COUNT,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
CONF_GBNF_GRAMMAR_FILE,
CONF_PROMPT_CACHING_ENABLED,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_LLAMACPP_BATCH_SIZE,
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
DEFAULT_LLAMACPP_THREAD_COUNT,
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_PROMPT_CACHING_ENABLED,
CONF_GENERIC_OPENAI_PATH,
)
@pytest.fixture
async def hass_defaults(hass):
return hass
def test_snapshot_settings_defaults():
options = {CONF_CHAT_MODEL: "test-model"}
snap = snapshot_settings(options)
assert snap[CONF_CONTEXT_LENGTH] == DEFAULT_CONTEXT_LENGTH
assert snap[CONF_LLAMACPP_BATCH_SIZE] == DEFAULT_LLAMACPP_BATCH_SIZE
assert snap[CONF_LLAMACPP_THREAD_COUNT] == DEFAULT_LLAMACPP_THREAD_COUNT
assert snap[CONF_LLAMACPP_BATCH_THREAD_COUNT] == DEFAULT_LLAMACPP_BATCH_THREAD_COUNT
assert snap[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] == DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION
assert snap[CONF_GBNF_GRAMMAR_FILE] == DEFAULT_GBNF_GRAMMAR_FILE
assert snap[CONF_PROMPT_CACHING_ENABLED] == DEFAULT_PROMPT_CACHING_ENABLED
def test_snapshot_settings_overrides():
options = {
CONF_CONTEXT_LENGTH: 4096,
CONF_LLAMACPP_BATCH_SIZE: 64,
CONF_LLAMACPP_THREAD_COUNT: 6,
CONF_LLAMACPP_BATCH_THREAD_COUNT: 3,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION: True,
CONF_GBNF_GRAMMAR_FILE: "custom.gbnf",
CONF_PROMPT_CACHING_ENABLED: True,
}
snap = snapshot_settings(options)
assert snap[CONF_CONTEXT_LENGTH] == 4096
assert snap[CONF_LLAMACPP_BATCH_SIZE] == 64
assert snap[CONF_LLAMACPP_THREAD_COUNT] == 6
assert snap[CONF_LLAMACPP_BATCH_THREAD_COUNT] == 3
assert snap[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] is True
assert snap[CONF_GBNF_GRAMMAR_FILE] == "custom.gbnf"
assert snap[CONF_PROMPT_CACHING_ENABLED] is True
def test_ollama_keep_alive_formatting():
assert OllamaAPIClient._format_keep_alive("0") == 0
assert OllamaAPIClient._format_keep_alive("0.0") == 0
assert OllamaAPIClient._format_keep_alive(5) == "5m"
assert OllamaAPIClient._format_keep_alive("15") == "15m"
def test_generic_openai_name_and_path(hass_defaults):
client = GenericOpenAIAPIClient(
hass_defaults,
{
CONF_HOST: "localhost",
CONF_PORT: "8080",
CONF_SSL: False,
CONF_GENERIC_OPENAI_PATH: "v1",
CONF_CHAT_MODEL: "demo",
},
)
name = client.get_name(
{
CONF_HOST: "localhost",
CONF_PORT: "8080",
CONF_SSL: False,
CONF_GENERIC_OPENAI_PATH: "v1",
}
)
assert "Generic OpenAI" in name
assert "localhost" in name
def test_normalize_path_helper():
assert _normalize_path(None) == ""
assert _normalize_path("") == ""
assert _normalize_path("/v1/") == "/v1"
assert _normalize_path("v2") == "/v2"

View File

@@ -1,350 +1,204 @@
"""Config flow option schema tests to ensure options are wired per-backend."""
import pytest
from unittest.mock import patch, MagicMock
from homeassistant import config_entries, setup
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.const import (
CONF_HOST,
CONF_PORT,
CONF_SSL,
CONF_LLM_HASS_API,
)
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import llm
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema, ConfigFlow
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_PROMPT_TEMPLATE,
CONF_TOOL_FORMAT,
CONF_TOOL_MULTI_TURN_CHAT,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
CONF_BATCH_SIZE,
CONF_THREAD_COUNT,
CONF_BATCH_THREAD_COUNT,
BACKEND_TYPE_LLAMA_HF,
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_LLAMA_CPP,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_LLAMA_CPP_SERVER,
BACKEND_TYPE_OLLAMA,
DEFAULT_CHAT_MODEL,
DEFAULT_PROMPT,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
CONF_CONTEXT_LENGTH,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_GBNF_GRAMMAR_FILE,
CONF_LLAMACPP_BATCH_SIZE,
CONF_LLAMACPP_BATCH_THREAD_COUNT,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
CONF_LLAMACPP_THREAD_COUNT,
CONF_MAX_TOKENS,
CONF_MIN_P,
CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_OLLAMA_JSON_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_PROMPT,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_REQUEST_TIMEOUT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_THINKING_PREFIX,
CONF_TOOL_CALL_PREFIX,
CONF_TOP_K,
CONF_TOP_P,
CONF_TYPICAL_P,
CONF_TEMPERATURE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_LLAMACPP_BATCH_SIZE,
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
DEFAULT_LLAMACPP_THREAD_COUNT,
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_PROMPT_CACHING_ENABLED,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_BATCH_SIZE,
DEFAULT_THREAD_COUNT,
DEFAULT_BATCH_THREAD_COUNT,
DOMAIN,
DEFAULT_PROMPT,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_THINKING_PREFIX,
DEFAULT_TOOL_CALL_PREFIX,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_TYPICAL_P,
)
# async def test_validate_config_flow_llama_hf(hass: HomeAssistant):
# result = await hass.config_entries.flow.async_init(
# DOMAIN, context={"source": config_entries.SOURCE_USER}
# )
# assert result["type"] == FlowResultType.FORM
# assert result["errors"] is None
# result2 = await hass.config_entries.flow.async_configure(
# result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_HF },
# )
# assert result2["type"] == FlowResultType.FORM
# with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry:
# result3 = await hass.config_entries.flow.async_configure(
# result2["flow_id"],
# TEST_DATA,
# )
# await hass.async_block_till_done()
# assert result3["type"] == "create_entry"
# assert result3["title"] == ""
# assert result3["data"] == {
# # ACCOUNT_ID: TEST_DATA["account_id"],
# # CONF_PASSWORD: TEST_DATA["password"],
# # CONNECTION_TYPE: CLOUD,
# }
# assert result3["options"] == {}
# assert len(mock_setup_entry.mock_calls) == 1
@pytest.fixture
def validate_connections_mock():
    validate_mock = MagicMock()
    with patch.object(ConfigFlow, '_validate_text_generation_webui', new=validate_mock), \
         patch.object(ConfigFlow, '_validate_ollama', new=validate_mock):
        yield validate_mock

@pytest.fixture
def mock_setup_entry():
    with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry, \
         patch("custom_components.llama_conversation.async_unload_entry", return_value=True):
        yield mock_setup_entry

async def test_validate_config_flow_generic_openai(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations):
    result = await hass.config_entries.flow.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert result["type"] == FlowResultType.FORM
    assert result["errors"] == {}
    assert result["step_id"] == "pick_backend"

    result2 = await hass.config_entries.flow.async_configure(
        result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI },
    )
    assert result2["type"] == FlowResultType.FORM
    assert result2["errors"] == {}
    assert result2["step_id"] == "remote_model"

    result3 = await hass.config_entries.flow.async_configure(
        result2["flow_id"],
        {
            CONF_HOST: "localhost",
            CONF_PORT: "5000",
            CONF_SSL: False,
            CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
        },
    )
    assert result3["type"] == FlowResultType.FORM
    assert result3["errors"] == {}
    assert result3["step_id"] == "model_parameters"

    options_dict = {
        CONF_PROMPT: DEFAULT_PROMPT,
        CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
        CONF_TOP_P: DEFAULT_TOP_P,
        CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
        CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
        CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
        CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
        CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
        CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
        CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
        CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
        CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
        CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
        CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
        CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
    }
    result4 = await hass.config_entries.flow.async_configure(
        result2["flow_id"], options_dict
    )
    await hass.async_block_till_done()

    assert result4["type"] == "create_entry"
    assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
    assert result4["data"] == {
        CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI,
        CONF_HOST: "localhost",
        CONF_PORT: "5000",
        CONF_SSL: False,
        CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
    }
    assert result4["options"] == options_dict
    assert len(mock_setup_entry.mock_calls) == 1

async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations, validate_connections_mock):
    result = await hass.config_entries.flow.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert result["type"] == FlowResultType.FORM
    assert result["errors"] == {}
    assert result["step_id"] == "pick_backend"

    result2 = await hass.config_entries.flow.async_configure(
        result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA },
    )
    assert result2["type"] == FlowResultType.FORM
    assert result2["errors"] == {}
    assert result2["step_id"] == "remote_model"

    # simulate incorrect settings on first try
    validate_connections_mock.side_effect = [
        ("failed_to_connect", Exception("ConnectionError"), []),
        (None, None, [])
    ]
    result3 = await hass.config_entries.flow.async_configure(
        result2["flow_id"],
        {
            CONF_HOST: "localhost",
            CONF_PORT: "5000",
            CONF_SSL: False,
            CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
        },
    )
    assert result3["type"] == FlowResultType.FORM
    assert len(result3["errors"]) == 1
    assert "base" in result3["errors"]
    assert result3["step_id"] == "remote_model"

    # retry
    result3 = await hass.config_entries.flow.async_configure(
        result2["flow_id"],
        {
            CONF_HOST: "localhost",
            CONF_PORT: "5001",
            CONF_SSL: False,
            CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
        },
    )
    assert result3["type"] == FlowResultType.FORM
    assert result3["errors"] == {}
    assert result3["step_id"] == "model_parameters"

    options_dict = {
        CONF_PROMPT: DEFAULT_PROMPT,
        CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
        CONF_TOP_P: DEFAULT_TOP_P,
        CONF_TOP_K: DEFAULT_TOP_K,
        CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
        CONF_TYPICAL_P: DEFAULT_MIN_P,
        CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
        CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
        CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
        CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
        CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
        CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
        CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
        CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
        CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
        CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
        CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
        CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
        CONF_OLLAMA_KEEP_ALIVE_MIN: DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
        CONF_OLLAMA_JSON_MODE: DEFAULT_OLLAMA_JSON_MODE,
    }
    result4 = await hass.config_entries.flow.async_configure(
        result2["flow_id"], options_dict
    )
    await hass.async_block_till_done()

    assert result4["type"] == "create_entry"
    assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
    assert result4["data"] == {
        CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA,
        CONF_HOST: "localhost",
        CONF_PORT: "5001",
        CONF_SSL: False,
        CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
    }
    assert result4["options"] == options_dict
    mock_setup_entry.assert_called_once()

# TODO: write tests for configflow setup for llama.cpp (both versions) + text-generation-webui

def _schema(hass: HomeAssistant, backend: str, options: dict | None = None):
    return local_llama_config_option_schema(
        hass=hass,
        language="en",
        options=options or {},
        backend_type=backend,
        subentry_type="conversation",
    )

def _get_default(schema: dict, key_name: str):
    for key in schema:
        if getattr(key, "schema", None) == key_name:
            default = getattr(key, "default", None)
            return default() if callable(default) else default
    raise AssertionError(f"Key {key_name} not found in schema")

def _get_suggested(schema: dict, key_name: str):
    for key in schema:
        if getattr(key, "schema", None) == key_name:
            return (getattr(key, "description", {}) or {}).get("suggested_value")
    raise AssertionError(f"Key {key_name} not found in schema")

def test_schema_llama_cpp_defaults_and_overrides(hass: HomeAssistant):
    overrides = {
        CONF_CONTEXT_LENGTH: 4096,
        CONF_LLAMACPP_BATCH_SIZE: 8,
        CONF_LLAMACPP_THREAD_COUNT: 6,
        CONF_LLAMACPP_BATCH_THREAD_COUNT: 3,
        CONF_LLAMACPP_ENABLE_FLASH_ATTENTION: True,
        CONF_PROMPT_CACHING_INTERVAL: 15,
        CONF_TOP_K: 12,
        CONF_TOOL_CALL_PREFIX: "<tc>",
    }
    schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP, overrides)

    expected_keys = {
        CONF_MAX_TOKENS,
        CONF_CONTEXT_LENGTH,
        CONF_TOP_K,
        CONF_TOP_P,
        CONF_MIN_P,
        CONF_TYPICAL_P,
        CONF_PROMPT_CACHING_ENABLED,
        CONF_PROMPT_CACHING_INTERVAL,
        CONF_GBNF_GRAMMAR_FILE,
        CONF_LLAMACPP_BATCH_SIZE,
        CONF_LLAMACPP_THREAD_COUNT,
        CONF_LLAMACPP_BATCH_THREAD_COUNT,
        CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
    }
    assert expected_keys.issubset({getattr(k, "schema", None) for k in schema})

    # defaults should stay unchanged regardless of overrides
    assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
    assert _get_default(schema, CONF_LLAMACPP_BATCH_SIZE) == DEFAULT_LLAMACPP_BATCH_SIZE
    assert _get_default(schema, CONF_LLAMACPP_THREAD_COUNT) == DEFAULT_LLAMACPP_THREAD_COUNT
    assert _get_default(schema, CONF_LLAMACPP_BATCH_THREAD_COUNT) == DEFAULT_LLAMACPP_BATCH_THREAD_COUNT
    assert _get_default(schema, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION) is DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION
    assert _get_default(schema, CONF_PROMPT_CACHING_INTERVAL) == DEFAULT_PROMPT_CACHING_INTERVAL

    # suggested values should reflect overrides
    assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 4096
    assert _get_suggested(schema, CONF_LLAMACPP_BATCH_SIZE) == 8
    assert _get_suggested(schema, CONF_LLAMACPP_THREAD_COUNT) == 6
    assert _get_suggested(schema, CONF_LLAMACPP_BATCH_THREAD_COUNT) == 3
    assert _get_suggested(schema, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION) is True
    assert _get_suggested(schema, CONF_PROMPT_CACHING_INTERVAL) == 15
    assert _get_suggested(schema, CONF_TOP_K) == 12
    assert _get_suggested(schema, CONF_TOOL_CALL_PREFIX) == "<tc>"

def test_schema_text_gen_webui_options_preserved(hass: HomeAssistant):
    overrides = {
        CONF_REQUEST_TIMEOUT: 123,
        CONF_TEXT_GEN_WEBUI_PRESET: "custom-preset",
        CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
        CONF_CONTEXT_LENGTH: 2048,
    }
    schema = _schema(hass, BACKEND_TYPE_TEXT_GEN_WEBUI, overrides)

    expected = {CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, CONF_REQUEST_TIMEOUT, CONF_CONTEXT_LENGTH}
    assert expected.issubset({getattr(k, "schema", None) for k in schema})
    assert _get_default(schema, CONF_REQUEST_TIMEOUT) == DEFAULT_REQUEST_TIMEOUT
    assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
    assert _get_suggested(schema, CONF_REQUEST_TIMEOUT) == 123
    assert _get_suggested(schema, CONF_TEXT_GEN_WEBUI_PRESET) == "custom-preset"
    assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 2048

def test_schema_generic_openai_options_preserved(hass: HomeAssistant):
    overrides = {CONF_TOP_P: 0.25, CONF_REQUEST_TIMEOUT: 321}
    schema = _schema(hass, BACKEND_TYPE_GENERIC_OPENAI, overrides)

    assert {CONF_TOP_P, CONF_REQUEST_TIMEOUT}.issubset({getattr(k, "schema", None) for k in schema})
    assert _get_default(schema, CONF_TOP_P) == DEFAULT_TOP_P
    assert _get_default(schema, CONF_REQUEST_TIMEOUT) == DEFAULT_REQUEST_TIMEOUT
    assert _get_suggested(schema, CONF_TOP_P) == 0.25
    assert _get_suggested(schema, CONF_REQUEST_TIMEOUT) == 321

    # Base prompt options still present
    prompt_default = _get_default(schema, CONF_PROMPT)
    assert prompt_default is not None and "You are 'Al'" in prompt_default
    assert _get_default(schema, CONF_NUM_IN_CONTEXT_EXAMPLES) == DEFAULT_NUM_IN_CONTEXT_EXAMPLES

def test_schema_llama_cpp_server_includes_gbnf(hass: HomeAssistant):
    schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP_SERVER)
    keys = {getattr(k, "schema", None) for k in schema}
    assert {CONF_MAX_TOKENS, CONF_TOP_K, CONF_GBNF_GRAMMAR_FILE}.issubset(keys)
    assert _get_default(schema, CONF_GBNF_GRAMMAR_FILE) == "output.gbnf"

def test_schema_ollama_defaults_and_overrides(hass: HomeAssistant):
    overrides = {CONF_OLLAMA_KEEP_ALIVE_MIN: 5, CONF_CONTEXT_LENGTH: 1024, CONF_TOP_K: 7}
    schema = _schema(hass, BACKEND_TYPE_OLLAMA, overrides)

    assert {CONF_MAX_TOKENS, CONF_CONTEXT_LENGTH, CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE}.issubset(
        {getattr(k, "schema", None) for k in schema}
    )
    assert _get_default(schema, CONF_OLLAMA_KEEP_ALIVE_MIN) == DEFAULT_OLLAMA_KEEP_ALIVE_MIN
    assert _get_default(schema, CONF_OLLAMA_JSON_MODE) is DEFAULT_OLLAMA_JSON_MODE
    assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
    assert _get_default(schema, CONF_TOP_K) == DEFAULT_TOP_K
    assert _get_suggested(schema, CONF_OLLAMA_KEEP_ALIVE_MIN) == 5
    assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 1024
    assert _get_suggested(schema, CONF_TOP_K) == 7

def test_schema_includes_llm_api_selector(monkeypatch, hass: HomeAssistant):
    monkeypatch.setattr(
        "custom_components.llama_conversation.config_flow.llm.async_get_apis",
        lambda _hass: [type("API", (), {"id": "dummy", "name": "Dummy API", "tools": []})()],
    )
    schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP)

    assert _get_default(schema, CONF_LLM_HASS_API) is None
    # Base prompt and thinking prefixes use defaults when not overridden
    prompt_default = _get_default(schema, CONF_PROMPT)
    assert prompt_default is not None and "You are 'Al'" in prompt_default
    assert _get_default(schema, CONF_THINKING_PREFIX) == DEFAULT_THINKING_PREFIX
    assert _get_default(schema, CONF_TOOL_CALL_PREFIX) == DEFAULT_TOOL_CALL_PREFIX

def test_validate_options_schema(hass: HomeAssistant):
    universal_options = [
        CONF_LLM_HASS_API, CONF_PROMPT, CONF_PROMPT_TEMPLATE, CONF_TOOL_FORMAT, CONF_TOOL_MULTI_TURN_CHAT,
        CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES,
        CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
        CONF_SERVICE_CALL_REGEX, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS,
    ]

    options_llama_hf = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_HF)
    assert set(options_llama_hf.keys()) == set(universal_options + [
        CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P,  # supports all sampling parameters
        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION,  # llama.cpp specific
        CONF_CONTEXT_LENGTH,  # supports context length
        CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE,  # supports GBNF
        CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL  # supports prompt caching
    ])

    options_llama_existing = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_EXISTING)
    assert set(options_llama_existing.keys()) == set(universal_options + [
        CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P,  # supports all sampling parameters
        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION,  # llama.cpp specific
        CONF_CONTEXT_LENGTH,  # supports context length
        CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE,  # supports GBNF
        CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL  # supports prompt caching
    ])

    options_ollama = local_llama_config_option_schema(hass, None, BACKEND_TYPE_OLLAMA)
    assert set(options_ollama.keys()) == set(universal_options + [
        CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_TYPICAL_P,  # supports top_k, temperature, top_p and typical_p samplers
        CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE,  # ollama specific
        CONF_CONTEXT_LENGTH,  # supports context length
        CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT,  # is a remote backend
    ])

    options_text_gen_webui = local_llama_config_option_schema(hass, None, BACKEND_TYPE_TEXT_GEN_WEBUI)
    assert set(options_text_gen_webui.keys()) == set(universal_options + [
        CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P,  # supports all sampling parameters
        CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET,  # text-gen-webui specific
        CONF_CONTEXT_LENGTH,  # supports context length
        CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT,  # is a remote backend
    ])

    options_generic_openai = local_llama_config_option_schema(hass, None, BACKEND_TYPE_GENERIC_OPENAI)
    assert set(options_generic_openai.keys()) == set(universal_options + [
        CONF_TEMPERATURE, CONF_TOP_P,  # only supports top_p and temperature sampling
        CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT,  # is a remote backend
    ])

    options_llama_cpp_python_server = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER)
    assert set(options_llama_cpp_python_server.keys()) == set(universal_options + [
        CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P,  # supports top_k, temperature, and top_p sampling
        CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE,  # supports GBNF
        CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT,  # is a remote backend
    ])

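The TODO in the flow tests above notes missing coverage for the text-generation-webui setup path. A minimal sketch of such a test, assuming the webui flow follows the same pick_backend → remote_model → model_parameters sequence and that the `validate_connections_mock` fixture (which also patches `_validate_text_generation_webui`) is reused; the test name and the single success return value are illustrative assumptions, not part of the diff:

async def test_validate_config_flow_text_gen_webui_sketch(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations, validate_connections_mock):
    # Assume validation succeeds on the first attempt.
    validate_connections_mock.return_value = (None, None, [])
    result = await hass.config_entries.flow.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert result["step_id"] == "pick_backend"
    result2 = await hass.config_entries.flow.async_configure(
        result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_TEXT_GEN_WEBUI },
    )
    assert result2["step_id"] == "remote_model"
    result3 = await hass.config_entries.flow.async_configure(
        result2["flow_id"],
        {
            CONF_HOST: "localhost",
            CONF_PORT: "5000",
            CONF_SSL: False,
            CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
        },
    )
    # Under the stated assumption, the flow should advance to the options step.
    assert result3["step_id"] == "model_parameters"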
View File

@@ -0,0 +1,114 @@
"""Tests for LocalLLMAgent async_process."""
import pytest
from contextlib import contextmanager
from homeassistant.components.conversation import ConversationInput, SystemContent, AssistantContent
from homeassistant.const import MATCH_ALL
from custom_components.llama_conversation.conversation import LocalLLMAgent
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
DEFAULT_PROMPT,
DOMAIN,
)
class DummyClient:
def __init__(self, hass):
self.hass = hass
self.generated_prompts = []
def _generate_system_prompt(self, prompt_template, llm_api, entity_options):
self.generated_prompts.append(prompt_template)
return "rendered-system-prompt"
async def _async_generate(self, conv, agent_id, chat_log, entity_options):
async def gen():
yield AssistantContent(agent_id=agent_id, content="hello from llm")
return gen()
class DummySubentry:
def __init__(self, subentry_id="sub1", title="Test Agent", chat_model="model"):
self.subentry_id = subentry_id
self.title = title
self.subentry_type = DOMAIN
self.data = {CONF_CHAT_MODEL: chat_model}
class DummyEntry:
def __init__(self, entry_id="entry1", options=None, subentry=None, runtime_data=None):
self.entry_id = entry_id
self.options = options or {}
self.subentries = {subentry.subentry_id: subentry}
self.runtime_data = runtime_data
def add_update_listener(self, _cb):
return lambda: None
class FakeChatLog:
def __init__(self):
self.content = []
self.llm_api = None
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
class FakeChatSession:
def __enter__(self):
return {}
def __exit__(self, exc_type, exc, tb):
return False
@pytest.mark.asyncio
async def test_async_process_generates_response(monkeypatch, hass):
client = DummyClient(hass)
subentry = DummySubentry()
entry = DummyEntry(subentry=subentry, runtime_data=client)
# Make entry discoverable through hass data as LocalLLMEntity expects.
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry
@contextmanager
def fake_chat_session(_hass, _conversation_id):
yield FakeChatSession()
@contextmanager
def fake_chat_log(_hass, _session, _user_input):
yield FakeChatLog()
monkeypatch.setattr(
"custom_components.llama_conversation.conversation.chat_session.async_get_chat_session",
fake_chat_session,
)
monkeypatch.setattr(
"custom_components.llama_conversation.conversation.conversation.async_get_chat_log",
fake_chat_log,
)
agent = LocalLLMAgent(hass, entry, subentry, client)
result = await agent.async_process(
ConversationInput(
text="turn on the lights",
context=None,
conversation_id="conv-id",
device_id=None,
language="en",
agent_id="agent-1",
)
)
assert result.response.speech["plain"]["speech"] == "hello from llm"
# System prompt should be rendered once when message history is empty.
assert client.generated_prompts == [DEFAULT_PROMPT]
assert agent.supported_languages == MATCH_ALL
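
A natural companion test would check that a configured prompt overrides DEFAULT_PROMPT. The sketch below assumes the per-agent prompt is read from the conversation subentry's data under CONF_PROMPT; if the integration reads it from the entry options instead, the override would move there. This is an illustrative sketch, not part of the diff:

@pytest.mark.asyncio
async def test_async_process_uses_configured_prompt(monkeypatch, hass):
    client = DummyClient(hass)
    subentry = DummySubentry()
    # Assumption: a custom prompt stored on the subentry is preferred over DEFAULT_PROMPT.
    subentry.data = {CONF_CHAT_MODEL: "model", CONF_PROMPT: "You only control the office."}
    entry = DummyEntry(subentry=subentry, runtime_data=client)
    hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry

    @contextmanager
    def fake_chat_session(_hass, _conversation_id):
        yield FakeChatSession()

    @contextmanager
    def fake_chat_log(_hass, _session, _user_input):
        yield FakeChatLog()

    monkeypatch.setattr(
        "custom_components.llama_conversation.conversation.chat_session.async_get_chat_session",
        fake_chat_session,
    )
    monkeypatch.setattr(
        "custom_components.llama_conversation.conversation.conversation.async_get_chat_log",
        fake_chat_log,
    )

    agent = LocalLLMAgent(hass, entry, subentry, client)
    await agent.async_process(
        ConversationInput(
            text="hi",
            context=None,
            conversation_id="conv-id-2",
            device_id=None,
            language="en",
            agent_id="agent-1",
        )
    )
    # Expectation under the stated assumption: the override, not DEFAULT_PROMPT, is rendered.
    assert client.generated_prompts == ["You only control the office."]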

View File

@@ -0,0 +1,162 @@
"""Tests for LocalLLMClient helpers in entity.py."""
import inspect
import json
import pytest
from json import JSONDecodeError
from custom_components.llama_conversation.entity import LocalLLMClient
from custom_components.llama_conversation.const import (
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_TOOL_CALL_PREFIX,
DEFAULT_TOOL_CALL_SUFFIX,
DEFAULT_THINKING_PREFIX,
DEFAULT_THINKING_SUFFIX,
)
class DummyLocalClient(LocalLLMClient):
@staticmethod
def get_name(_client_options):
return "dummy"
class DummyLLMApi:
def __init__(self):
self.tools = []
@pytest.fixture
def client(hass):
# Disable ICL loading during tests to avoid filesystem access.
return DummyLocalClient(hass, {CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False})
@pytest.mark.asyncio
async def test_async_parse_completion_parses_tool_call(client):
raw_tool = '{"name":"light.turn_on","arguments":{"brightness":0.5,"to_say":" acknowledged"}}'
completion = (
f"{DEFAULT_THINKING_PREFIX}internal{DEFAULT_THINKING_SUFFIX}"
f"hello {DEFAULT_TOOL_CALL_PREFIX}{raw_tool}{DEFAULT_TOOL_CALL_SUFFIX}"
)
result = await client._async_parse_completion(DummyLLMApi(), "agent-id", {}, completion)
assert result.response.strip().startswith("hello")
assert "acknowledged" in result.response
assert result.tool_calls
tool_call = result.tool_calls[0]
assert tool_call.tool_name == "light.turn_on"
assert tool_call.tool_args["brightness"] == 127
@pytest.mark.asyncio
async def test_async_parse_completion_ignores_tools_without_llm_api(client):
raw_tool = '{"name":"light.turn_on","arguments":{"brightness":1}}'
completion = f"hello {DEFAULT_TOOL_CALL_PREFIX}{raw_tool}{DEFAULT_TOOL_CALL_SUFFIX}"
result = await client._async_parse_completion(None, "agent-id", {}, completion)
assert result.tool_calls == []
assert result.response.strip() == "hello"
@pytest.mark.asyncio
async def test_async_parse_completion_malformed_tool_raises(client):
bad_tool = f"{DEFAULT_TOOL_CALL_PREFIX}{{not-json{DEFAULT_TOOL_CALL_SUFFIX}"
with pytest.raises(JSONDecodeError):
await client._async_parse_completion(DummyLLMApi(), "agent-id", {}, bad_tool)
@pytest.mark.asyncio
async def test_async_stream_parse_completion_handles_streamed_tool_call(client):
async def token_generator():
yield ("Hi", None)
yield (
None,
[
{
"function": {
"name": "light.turn_on",
"arguments": {"brightness": 0.25, "to_say": " ok"},
}
}
],
)
stream = client._async_stream_parse_completion(
DummyLLMApi(), "agent-id", {}, anext_token=token_generator()
)
results = [chunk async for chunk in stream]
assert results[0].response == "Hi"
assert results[1].response.strip() == "ok"
assert results[1].tool_calls[0].tool_args["brightness"] == 63
@pytest.mark.asyncio
async def test_async_stream_parse_completion_malformed_tool_raises(client):
async def token_generator():
yield ("Hi", None)
yield (None, ["{not-json"])
with pytest.raises(JSONDecodeError):
async for _chunk in client._async_stream_parse_completion(
DummyLLMApi(), "agent-id", {}, anext_token=token_generator()
):
pass
@pytest.mark.asyncio
async def test_async_stream_parse_completion_ignores_tools_without_llm_api(client):
async def token_generator():
yield ("Hi", None)
yield (None, ["{}"])
results = [chunk async for chunk in client._async_stream_parse_completion(
None, "agent-id", {}, anext_token=token_generator()
)]
assert results[0].response == "Hi"
assert results[1].tool_calls is None
@pytest.mark.asyncio
async def test_async_get_exposed_entities_respects_exposure(monkeypatch, client, hass):
hass.states.async_set("light.exposed", "on", {"friendly_name": "Lamp"})
hass.states.async_set("switch.hidden", "off", {"friendly_name": "Hidden"})
monkeypatch.setattr(
"custom_components.llama_conversation.entity.async_should_expose",
lambda _hass, _domain, entity_id: not entity_id.endswith("hidden"),
)
exposed = client._async_get_exposed_entities()
assert "light.exposed" in exposed
assert "switch.hidden" not in exposed
assert exposed["light.exposed"]["friendly_name"] == "Lamp"
assert exposed["light.exposed"]["state"] == "on"
@pytest.mark.asyncio
async def test_generate_system_prompt_renders(monkeypatch, client, hass):
hass.states.async_set("light.kitchen", "on", {"friendly_name": "Kitchen"})
monkeypatch.setattr(
"custom_components.llama_conversation.entity.async_should_expose",
lambda _hass, _domain, _entity_id: True,
)
rendered = client._generate_system_prompt(
"Devices:\n{{ formatted_devices }}",
llm_api=None,
entity_options={CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: []},
)
if inspect.iscoroutine(rendered):
rendered = await rendered
assert isinstance(rendered, str)
assert "light.kitchen" in rendered

View File

@@ -0,0 +1,159 @@
"""Regression tests for config entry migration in __init__.py."""
import pytest
from homeassistant.const import CONF_LLM_HASS_API, CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.config_entries import ConfigSubentry
from pytest_homeassistant_custom_component.common import MockConfigEntry
from custom_components.llama_conversation import async_migrate_entry
from custom_components.llama_conversation.const import (
BACKEND_TYPE_LLAMA_CPP,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_SERVER,
CONF_BACKEND_TYPE,
CONF_CHAT_MODEL,
CONF_CONTEXT_LENGTH,
CONF_DOWNLOADED_MODEL_FILE,
CONF_DOWNLOADED_MODEL_QUANTIZATION,
CONF_GENERIC_OPENAI_PATH,
CONF_PROMPT,
CONF_REQUEST_TIMEOUT,
DOMAIN,
)
@pytest.mark.asyncio
async def test_migrate_v1_is_rejected(hass):
entry = MockConfigEntry(domain=DOMAIN, data={CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_CPP}, version=1)
entry.add_to_hass(hass)
result = await async_migrate_entry(hass, entry)
assert result is False
@pytest.mark.asyncio
async def test_migrate_v2_creates_subentry_and_updates_entry(monkeypatch, hass):
entry = MockConfigEntry(
domain=DOMAIN,
title="llama 'Test Agent' entry",
data={CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI},
options={
CONF_HOST: "localhost",
CONF_PORT: "8080",
CONF_SSL: False,
CONF_GENERIC_OPENAI_PATH: "v1",
CONF_PROMPT: "hello",
CONF_REQUEST_TIMEOUT: 90,
CONF_CHAT_MODEL: "model-x",
CONF_CONTEXT_LENGTH: 1024,
},
version=2,
)
entry.add_to_hass(hass)
added_subentries = []
update_calls = []
def fake_add_subentry(cfg_entry, subentry):
added_subentries.append((cfg_entry, subentry))
def fake_update_entry(cfg_entry, **kwargs):
update_calls.append(kwargs)
monkeypatch.setattr(hass.config_entries, "async_add_subentry", fake_add_subentry)
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
result = await async_migrate_entry(hass, entry)
assert result is True
assert added_subentries, "Subentry should be added"
subentry = added_subentries[0][1]
assert isinstance(subentry, ConfigSubentry)
assert subentry.subentry_type == "conversation"
assert subentry.data[CONF_CHAT_MODEL] == "model-x"
# Entry should be updated to version 3 with data/options separated
assert any(call.get("version") == 3 for call in update_calls)
last_options = [c["options"] for c in update_calls if "options" in c][-1]
assert last_options[CONF_HOST] == "localhost"
assert CONF_PROMPT not in last_options # moved to subentry
@pytest.mark.asyncio
async def test_migrate_v3_minor0_downloads_model(monkeypatch, hass):
sub_data = {
CONF_CHAT_MODEL: "model-a",
CONF_DOWNLOADED_MODEL_QUANTIZATION: "Q4_K_M",
CONF_REQUEST_TIMEOUT: 30,
}
subentry = ConfigSubentry(data=sub_data, subentry_type="conversation", title="sub", unique_id=None)
entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_CPP},
options={},
version=3,
minor_version=0,
)
entry.subentries = {"sub": subentry}
entry.add_to_hass(hass)
updated_subentries = []
update_calls = []
def fake_update_subentry(cfg_entry, old_sub, *, data=None, **_kwargs):
updated_subentries.append((cfg_entry, old_sub, data))
def fake_update_entry(cfg_entry, **kwargs):
update_calls.append(kwargs)
monkeypatch.setattr(
"custom_components.llama_conversation.download_model_from_hf", lambda *_args, **_kw: "file.gguf"
)
monkeypatch.setattr(hass.config_entries, "async_update_subentry", fake_update_subentry)
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
result = await async_migrate_entry(hass, entry)
assert result is True
assert updated_subentries, "Subentry should be updated with downloaded file"
new_data = updated_subentries[0][2]
assert new_data[CONF_DOWNLOADED_MODEL_FILE] == "file.gguf"
assert any(call.get("minor_version") == 1 for call in update_calls)
@pytest.mark.parametrize(
"api_value,expected_list",
[("api-1", ["api-1"]), (None, [])],
)
@pytest.mark.asyncio
async def test_migrate_v3_minor1_converts_api_to_list(monkeypatch, hass, api_value, expected_list):
entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI},
options={CONF_LLM_HASS_API: api_value},
version=3,
minor_version=1,
)
entry.add_to_hass(hass)
calls = []
def fake_update_entry(cfg_entry, **kwargs):
calls.append(kwargs)
if "options" in kwargs:
cfg_entry._options = kwargs["options"] # type: ignore[attr-defined]
if "minor_version" in kwargs:
cfg_entry._minor_version = kwargs["minor_version"] # type: ignore[attr-defined]
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
result = await async_migrate_entry(hass, entry)
assert result is True
options_calls = [c for c in calls if "options" in c]
assert options_calls, "async_update_entry should be called with options"
assert options_calls[-1]["options"][CONF_LLM_HASS_API] == expected_list
minor_calls = [c for c in calls if c.get("minor_version")]
assert minor_calls and minor_calls[-1]["minor_version"] == 2
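
For reference, the option conversion that the parametrized test above asserts can be sketched as a standalone helper; this is an illustration derived from the test's expected values ("api-1" → ["api-1"], None → []), not the integration's actual migration code:

def _as_api_list(value):
    # None (no API selected) becomes an empty list, a single id becomes a one-element list,
    # and an existing list would pass through unchanged.
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]

assert _as_api_list("api-1") == ["api-1"]
assert _as_api_list(None) == []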