mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 05:14:02 -05:00
7
.gitignore
vendored
7
.gitignore
vendored
@@ -3,12 +3,13 @@ loras/
|
||||
core/
|
||||
config/
|
||||
.DS_Store
|
||||
data/*.json
|
||||
data/*.jsonl
|
||||
data/**/*.json
|
||||
data/**/*.jsonl
|
||||
*.pyc
|
||||
main.log
|
||||
.venv
|
||||
*.xlsx
|
||||
notes.txt
|
||||
runpod_bootstrap.sh
|
||||
*.code-workspace
|
||||
*.code-workspace
|
||||
.coverage
|
||||
@@ -4,16 +4,17 @@ This project provides the required "glue" components to control your Home Assist
|
||||
## Quick Start
|
||||
Please see the [Setup Guide](./docs/Setup.md) for more information on installation.
|
||||
|
||||
## Local LLM Conversation Integration
|
||||
## Local LLM Integration
|
||||
**The latest version of this integration requires Home Assistant 2025.7.0 or newer**
|
||||
|
||||
In order to integrate with Home Assistant, we provide a custom component that exposes the locally running LLM as a "conversation agent".
|
||||
In order to integrate with Home Assistant, we provide a custom component that exposes the locally running LLM as a "conversation agent" or as an "ai task handler".
|
||||
|
||||
This component can be interacted with in a few ways:
|
||||
- using a chat interface so you can chat with it.
|
||||
- integrating with Speech-to-Text and Text-to-Speech addons so you can just speak to it.
|
||||
- using automations or scripts to trigger "ai tasks"; these process input data with a prompt, and return structured data that can be used in further automations.
|
||||
|
||||
The integration can either run the model in 2 different ways:
|
||||
The integration can either run the model in a few ways:
|
||||
1. Directly as part of the Home Assistant software using llama-cpp-python
|
||||
2. On a separate machine using one of the following backends:
|
||||
- [Ollama](https://ollama.com/) (easier)
|
||||
@@ -36,6 +37,7 @@ The latest models can be found on HuggingFace:
|
||||
|
||||
**Gemma3**:
|
||||
1B: TBD
|
||||
270M: TBD
|
||||
|
||||
<details>
|
||||
|
||||
@@ -158,6 +160,7 @@ python3 train.py \
|
||||
## Version History
|
||||
| Version | Description |
|
||||
|---------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| v0.4.5 | Add support for AI Task entities, Replace custom Ollama API implementation with the official `ollama-python` package to avoid future compatibility issues, Support multiple LLM APIs at once, Fix issues in tool call handling for various backends |
|
||||
| v0.4.4 | Fix issue with OpenAI backends appending `/v1` to all URLs, and fix an issue with tools being serialized into the system prompt. |
|
||||
| v0.4.3 | Fix an issue with the integration not creating model configs properly during setup |
|
||||
| v0.4.2 | Fix the following issues: not correctly setting default model settings during initial setup, non-integers being allowed in numeric config fields, being too strict with finish_reason requirements, and not letting the user clear the active LLM API |
|
||||
|
||||
@@ -57,7 +57,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
|
||||
PLATFORMS = (Platform.CONVERSATION, ) # Platform.AI_TASK)
|
||||
PLATFORMS = (Platform.CONVERSATION, Platform.AI_TASK)
|
||||
|
||||
BACKEND_TO_CLS: dict[str, type[LocalLLMClient]] = {
|
||||
BACKEND_TYPE_LLAMA_CPP: LlamaCppClient,
|
||||
@@ -184,6 +184,21 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigE
|
||||
|
||||
_LOGGER.debug("Migration to add downloaded model file complete")
|
||||
|
||||
if config_entry.version == 3 and config_entry.minor_version == 1:
|
||||
# convert selected APIs from single value to list
|
||||
api_to_convert = config_entry.options.get(CONF_LLM_HASS_API)
|
||||
new_options = dict(config_entry.options)
|
||||
if api_to_convert is not None:
|
||||
new_options[CONF_LLM_HASS_API] = [api_to_convert]
|
||||
else:
|
||||
new_options[CONF_LLM_HASS_API] = []
|
||||
|
||||
hass.config_entries.async_update_entry(
|
||||
config_entry, options=MappingProxyType(new_options)
|
||||
)
|
||||
hass.config_entries.async_update_entry(config_entry, minor_version=2)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
class HassServiceTool(llm.Tool):
|
||||
@@ -255,14 +270,14 @@ class HomeLLMAPI(llm.API):
|
||||
super().__init__(
|
||||
hass=hass,
|
||||
id=HOME_LLM_API_ID,
|
||||
name="Home-LLM (v1-v3)",
|
||||
name="Home Assistant Services",
|
||||
)
|
||||
|
||||
async def async_get_api_instance(self, llm_context: llm.LLMContext) -> llm.APIInstance:
|
||||
"""Return the instance of the API."""
|
||||
return llm.APIInstance(
|
||||
api=self,
|
||||
api_prompt="Call services in Home Assistant by passing the service name and the device to control.",
|
||||
api_prompt="Call services in Home Assistant by passing the service name and the device to control. Designed for Home-LLM Models (v1-v3)",
|
||||
llm_context=llm_context,
|
||||
tools=[HassServiceTool()],
|
||||
)
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
"""AI Task integration for Local LLMs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from json import JSONDecodeError
|
||||
import logging
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert as convert_to_openapi
|
||||
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, Context
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
from homeassistant.util.json import json_loads
|
||||
@@ -16,7 +20,13 @@ from homeassistant.util.json import json_loads
|
||||
from .entity import LocalLLMEntity, LocalLLMClient
|
||||
from .const import (
|
||||
CONF_PROMPT,
|
||||
DEFAULT_PROMPT,
|
||||
CONF_RESPONSE_JSON_SCHEMA,
|
||||
DEFAULT_AI_TASK_PROMPT,
|
||||
CONF_AI_TASK_RETRIES,
|
||||
DEFAULT_AI_TASK_RETRIES,
|
||||
CONF_AI_TASK_EXTRACTION_METHOD,
|
||||
DEFAULT_AI_TASK_EXTRACTION_METHOD,
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -29,98 +39,222 @@ async def async_setup_entry(
|
||||
) -> None:
|
||||
"""Set up AI Task entities."""
|
||||
for subentry in config_entry.subentries.values():
|
||||
if subentry.subentry_type != "ai_task_data":
|
||||
if subentry.subentry_type != ai_task.DOMAIN:
|
||||
continue
|
||||
|
||||
async_add_entities(
|
||||
[LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)],
|
||||
config_subentry_id=subentry.subentry_id,
|
||||
)
|
||||
# create one entity per subentry
|
||||
ai_task_entity = LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)
|
||||
|
||||
# make sure model is loaded
|
||||
await config_entry.runtime_data._async_load_model(dict(subentry.data))
|
||||
|
||||
# register the ai task entity
|
||||
async_add_entities([ai_task_entity], config_subentry_id=subentry.subentry_id)
|
||||
|
||||
|
||||
class ResultExtractionMethod(StrEnum):
|
||||
NONE = "none"
|
||||
STRUCTURED_OUTPUT = "structure"
|
||||
TOOL = "tool"
|
||||
|
||||
class SubmitResponseTool(llm.Tool):
|
||||
name = "submit_response"
|
||||
description = "Submit the structured response payload for the AI task"
|
||||
|
||||
def __init__(self, parameters_schema: vol.Schema):
|
||||
self.parameters = parameters_schema
|
||||
|
||||
async def async_call(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
tool_input: llm.ToolInput,
|
||||
llm_context: llm.LLMContext,
|
||||
) -> dict:
|
||||
return tool_input.tool_args or {}
|
||||
|
||||
|
||||
class SubmitResponseAPI(llm.API):
|
||||
def __init__(self, hass: HomeAssistant, tools: list[llm.Tool]) -> None:
|
||||
self._tools = tools
|
||||
super().__init__(
|
||||
hass=hass,
|
||||
id=f"{DOMAIN}-ai-task-tool",
|
||||
name="AI Task Tool API",
|
||||
)
|
||||
|
||||
async def async_get_api_instance(
|
||||
self, llm_context: llm.LLMContext
|
||||
) -> llm.APIInstance:
|
||||
return llm.APIInstance(
|
||||
api=self,
|
||||
api_prompt="Call submit_response to return the structured AI task result.",
|
||||
llm_context=llm_context,
|
||||
tools=self._tools,
|
||||
custom_serializer=llm.selector_serializer,
|
||||
)
|
||||
|
||||
|
||||
class LocalLLMTaskEntity(
|
||||
ai_task.AITaskEntity,
|
||||
LocalLLMEntity,
|
||||
):
|
||||
"""Ollama AI Task entity."""
|
||||
"""AI Task entity."""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
"""Initialize Ollama AI Task entity."""
|
||||
"""Initialize AI Task entity."""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if self.client._supports_vision(self.runtime_options):
|
||||
self._attr_supported_features = (
|
||||
ai_task.AITaskEntityFeature.GENERATE_DATA |
|
||||
ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||
ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||
)
|
||||
else:
|
||||
self._attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
|
||||
async def _generate_once(
|
||||
self,
|
||||
message_history: list[conversation.Content],
|
||||
llm_api: llm.APIInstance | None,
|
||||
entity_options: dict[str, Any],
|
||||
) -> tuple[str, list | None, Exception | None]:
|
||||
"""Generate a single response from the LLM."""
|
||||
collected_tools = None
|
||||
text = ""
|
||||
|
||||
# call the LLM client directly (not _async_generate) since that will attempt to execute tool calls
|
||||
try:
|
||||
if hasattr(self.client, "_generate_stream"):
|
||||
async for chunk in self.client._generate_stream(
|
||||
message_history,
|
||||
llm_api,
|
||||
self.entity_id,
|
||||
entity_options,
|
||||
):
|
||||
if chunk.response:
|
||||
text += chunk.response.strip()
|
||||
if chunk.tool_calls:
|
||||
collected_tools = chunk.tool_calls
|
||||
else:
|
||||
blocking_result = await self.client._generate(
|
||||
message_history,
|
||||
llm_api,
|
||||
self.entity_id,
|
||||
entity_options,
|
||||
)
|
||||
if blocking_result.response:
|
||||
text = blocking_result.response.strip()
|
||||
if blocking_result.tool_calls:
|
||||
collected_tools = blocking_result.tool_calls
|
||||
|
||||
_LOGGER.debug("AI Task '%s' generated text: %s (tools=%s)", self.entity_id, text, collected_tools)
|
||||
return text, collected_tools, None
|
||||
except JSONDecodeError as err:
|
||||
_LOGGER.debug("AI Task '%s' json error generated text: %s (tools=%s)", self.entity_id, text, collected_tools)
|
||||
return text, collected_tools, err
|
||||
|
||||
def _extract_data(
|
||||
self,
|
||||
raw_text: str,
|
||||
tool_calls: list[llm.ToolInput] | None,
|
||||
extraction_method: ResultExtractionMethod,
|
||||
chat_log: conversation.ChatLog,
|
||||
structure: vol.Schema | None,
|
||||
) -> tuple[ai_task.GenDataTaskResult | None, Exception | None]:
|
||||
"""Extract the final data from the LLM response based on the extraction method."""
|
||||
try:
|
||||
if extraction_method == ResultExtractionMethod.NONE or structure is None:
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=raw_text,
|
||||
), None
|
||||
|
||||
if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
|
||||
data = json_loads(raw_text)
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=data,
|
||||
), None
|
||||
|
||||
if extraction_method == ResultExtractionMethod.TOOL:
|
||||
first_tool = next(iter(tool_calls or []), None)
|
||||
if not first_tool:
|
||||
return None, HomeAssistantError("Please produce at least one tool call with the structured response.")
|
||||
|
||||
structure(first_tool.tool_args) # validate tool call against vol schema structure
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=first_tool.tool_args,
|
||||
), None
|
||||
except vol.Invalid as err:
|
||||
if isinstance(err, vol.MultipleInvalid):
|
||||
# combine all error messages into one
|
||||
error_message = "; ".join(f"Error at '{e.path}': {e.error_message}" for e in err.errors)
|
||||
else:
|
||||
error_message = f"Error at '{err.path}': {err.error_message}"
|
||||
return None, HomeAssistantError(f"Please address the following schema errors: {error_message}")
|
||||
except JSONDecodeError as err:
|
||||
return None, HomeAssistantError(f"Please produce properly formatted JSON: {repr(err)}")
|
||||
|
||||
raise HomeAssistantError(f"Invalid extraction method for AI Task {extraction_method}")
|
||||
|
||||
async def _async_generate_data(
|
||||
self,
|
||||
task: ai_task.GenDataTask,
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> ai_task.GenDataTaskResult:
|
||||
"""Handle a generate data task."""
|
||||
raw_task_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT)
|
||||
retries = max(0, self.runtime_options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES))
|
||||
extraction_method = self.runtime_options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)
|
||||
max_attempts = retries + 1
|
||||
|
||||
extraction_method = ResultExtractionMethod.NONE
|
||||
|
||||
try:
|
||||
raw_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
|
||||
message_history = chat_log.content[:]
|
||||
|
||||
if not isinstance(message_history[0], conversation.SystemContent):
|
||||
system_prompt = conversation.SystemContent(content=self.client._generate_system_prompt(raw_prompt, None, self.runtime_options))
|
||||
message_history.insert(0, system_prompt)
|
||||
|
||||
_LOGGER.debug(f"Generating response for {task.name=}...")
|
||||
generation_result = await self.client._async_generate(message_history, self.entity_id, chat_log, self.runtime_options)
|
||||
|
||||
assistant_message = await anext(generation_result)
|
||||
if not isinstance(assistant_message, conversation.AssistantContent):
|
||||
raise HomeAssistantError("Last content in chat log is not an AssistantContent!")
|
||||
text = assistant_message.content
|
||||
|
||||
if not task.structure:
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=text,
|
||||
)
|
||||
|
||||
if extraction_method == ResultExtractionMethod.NONE:
|
||||
raise HomeAssistantError("Task structure provided but no extraction method was specified!")
|
||||
elif extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
|
||||
try:
|
||||
data = json_loads(text)
|
||||
except JSONDecodeError as err:
|
||||
_LOGGER.error(
|
||||
"Failed to parse JSON response: %s. Response: %s",
|
||||
err,
|
||||
text,
|
||||
)
|
||||
raise HomeAssistantError("Error with Local LLM structured response") from err
|
||||
entity_options = {**self.runtime_options}
|
||||
if task.structure: # set up extraction method specifics
|
||||
if extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
|
||||
_LOGGER.debug("Using structure for AI Task '%s': %s", task.name, task.structure)
|
||||
entity_options[CONF_RESPONSE_JSON_SCHEMA] = convert_to_openapi(task.structure, custom_serializer=llm.selector_serializer)
|
||||
elif extraction_method == ResultExtractionMethod.TOOL:
|
||||
try:
|
||||
data = assistant_message.tool_calls[0].tool_args
|
||||
except (IndexError, AttributeError) as err:
|
||||
_LOGGER.error(
|
||||
"Failed to extract tool arguments from response: %s. Response: %s",
|
||||
err,
|
||||
text,
|
||||
)
|
||||
raise HomeAssistantError("Error with Local LLM tool response") from err
|
||||
else:
|
||||
raise ValueError() # should not happen
|
||||
chat_log.llm_api = await SubmitResponseAPI(self.hass, [SubmitResponseTool(task.structure)]).async_get_api_instance(
|
||||
llm.LLMContext(DOMAIN, context=None, language=None, assistant=None, device_id=None)
|
||||
)
|
||||
|
||||
message_history = list(chat_log.content) if chat_log.content else []
|
||||
task_prompt = self.client._generate_system_prompt(raw_task_prompt, llm_api=chat_log.llm_api, entity_options=entity_options)
|
||||
system_message = conversation.SystemContent(content=task_prompt)
|
||||
if message_history and isinstance(message_history[0], conversation.SystemContent):
|
||||
message_history[0] = system_message
|
||||
else:
|
||||
message_history.insert(0, system_message)
|
||||
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=data,
|
||||
if not any(isinstance(msg, conversation.UserContent) for msg in message_history):
|
||||
message_history.append(
|
||||
conversation.UserContent(
|
||||
content=task.instructions, attachments=task.attachments
|
||||
)
|
||||
)
|
||||
try:
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(max_attempts):
|
||||
_LOGGER.debug("Generating response for %s (attempt %s/%s)...", task.name, attempt + 1, max_attempts)
|
||||
text, tool_calls, err = await self._generate_once(message_history, chat_log.llm_api, entity_options)
|
||||
if err:
|
||||
last_error = err
|
||||
message_history.append(conversation.AssistantContent(agent_id=self.entity_id, content=text, tool_calls=tool_calls))
|
||||
message_history.append(conversation.UserContent(content=f"Error: {str(err)}. Please try again."))
|
||||
continue
|
||||
|
||||
data, err = self._extract_data(text, tool_calls, extraction_method, chat_log, task.structure)
|
||||
if err:
|
||||
last_error = err
|
||||
message_history.append(conversation.AssistantContent(agent_id=self.entity_id, content=text, tool_calls=tool_calls))
|
||||
message_history.append(conversation.UserContent(content=f"Error: {str(err)}. Please try again."))
|
||||
continue
|
||||
|
||||
if data:
|
||||
return data
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Unhandled exception while running AI Task '%s'", task.name)
|
||||
raise HomeAssistantError(f"Unhandled error while running AI Task '{task.name}'") from err
|
||||
|
||||
raise last_error or HomeAssistantError(f"AI Task '{task.name}' failed after {max_attempts} attempts")
|
||||
|
||||
@@ -28,6 +28,7 @@ from custom_components.llama_conversation.const import (
|
||||
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
|
||||
CONF_GENERIC_OPENAI_PATH,
|
||||
CONF_ENABLE_LEGACY_TOOL_CALLING,
|
||||
CONF_RESPONSE_JSON_SCHEMA,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_P,
|
||||
@@ -110,7 +111,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
models_result = await response.json()
|
||||
except:
|
||||
except (asyncio.TimeoutError, aiohttp.ClientResponseError):
|
||||
_LOGGER.exception("Failed to get available models")
|
||||
return RECOMMENDED_CHAT_MODELS
|
||||
|
||||
@@ -120,7 +121,8 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
conversation: List[conversation.Content],
|
||||
llm_api: llm.APIInstance | None,
|
||||
agent_id: str,
|
||||
entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
|
||||
entity_options: dict[str, Any],
|
||||
) -> AsyncGenerator[TextGenerationResult, None]:
|
||||
model_name = entity_options[CONF_CHAT_MODEL]
|
||||
temperature = entity_options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
||||
top_p = entity_options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
||||
@@ -138,6 +140,17 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
"messages": messages
|
||||
}
|
||||
|
||||
response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
|
||||
if response_json_schema:
|
||||
request_params["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "ha_task",
|
||||
"schema": response_json_schema,
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
|
||||
tools = None
|
||||
# "legacy" tool calling passes the tools directly as part of the system prompt instead of as "tools"
|
||||
# most local backends absolutely butcher any sort of prompt formatting when using tool calling
|
||||
@@ -155,7 +168,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
|
||||
session = async_get_clientsession(self.hass)
|
||||
|
||||
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
|
||||
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
|
||||
response = None
|
||||
chunk = None
|
||||
try:
|
||||
@@ -175,7 +188,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
break
|
||||
|
||||
if chunk and chunk.strip():
|
||||
to_say, tool_calls = self._extract_response(json.loads(chunk), llm_api, agent_id)
|
||||
to_say, tool_calls = self._extract_response(json.loads(chunk))
|
||||
if to_say or tool_calls:
|
||||
yield to_say, tool_calls
|
||||
except asyncio.TimeoutError as err:
|
||||
@@ -183,14 +196,14 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
except aiohttp.ClientError as err:
|
||||
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
|
||||
|
||||
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
|
||||
return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
|
||||
|
||||
def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
request_params = {}
|
||||
endpoint = "/chat/completions"
|
||||
return endpoint, request_params
|
||||
|
||||
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, agent_id: str) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
|
||||
def _extract_response(self, response_json: dict) -> Tuple[Optional[str], Optional[List[dict]]]:
|
||||
if "choices" not in response_json or len(response_json["choices"]) == 0: # finished
|
||||
_LOGGER.warning("Response missing or empty 'choices'. Keys present: %s. Full response: %s",
|
||||
list(response_json.keys()), response_json)
|
||||
@@ -204,16 +217,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
elif response_json["object"] == "chat.completion.chunk":
|
||||
response_text = choice["delta"].get("content", "")
|
||||
if "tool_calls" in choice["delta"] and choice["delta"]["tool_calls"] is not None:
|
||||
tool_calls = []
|
||||
for call in choice["delta"]["tool_calls"]:
|
||||
tool_call, to_say = parse_raw_tool_call(
|
||||
call["function"], llm_api, agent_id)
|
||||
|
||||
if tool_call:
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
if to_say:
|
||||
response_text += to_say
|
||||
tool_calls = [call["function"] for call in choice["delta"]["tool_calls"]]
|
||||
streamed = True
|
||||
else:
|
||||
response_text = choice["text"]
|
||||
@@ -267,7 +271,6 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
|
||||
try:
|
||||
if msg.role == "user":
|
||||
input_text = msg.content
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
@@ -367,7 +370,8 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
|
||||
conversation: List[conversation.Content],
|
||||
llm_api: llm.APIInstance | None,
|
||||
agent_id: str,
|
||||
entity_options: dict[str, Any]) -> TextGenerationResult:
|
||||
entity_options: dict[str, Any],
|
||||
) -> TextGenerationResult:
|
||||
"""Generate a response using the OpenAI-compatible Responses API (non-streaming endpoint wrapped as a single-chunk stream)."""
|
||||
|
||||
model_name = entity_options.get(CONF_CHAT_MODEL)
|
||||
@@ -377,6 +381,16 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
|
||||
request_params: Dict[str, Any] = {
|
||||
"model": model_name,
|
||||
}
|
||||
response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
|
||||
if response_json_schema:
|
||||
request_params["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "ha_task",
|
||||
"schema": response_json_schema,
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
request_params.update(additional_params)
|
||||
|
||||
headers: Dict[str, Any] = {}
|
||||
@@ -398,7 +412,10 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
|
||||
|
||||
try:
|
||||
text = self._extract_response(response_json)
|
||||
return TextGenerationResult(response=text, response_streamed=False)
|
||||
if not text:
|
||||
return TextGenerationResult(raise_error=True, error_msg="The Responses API returned an empty response.")
|
||||
# return await self._async_parse_completion(llm_api, agent_id, entity_options, text)
|
||||
return TextGenerationResult(response=text) # Currently we don't extract any info from the response besides the raw model output
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Failed to parse Responses API payload: %s", err)
|
||||
return TextGenerationResult(raise_error=True, error_msg=f"Failed to parse Responses API payload: {err}")
|
||||
|
||||
@@ -9,7 +9,6 @@ import time
|
||||
from typing import Any, Callable, List, Generator, AsyncGenerator, Optional, cast
|
||||
|
||||
from homeassistant.components import conversation as conversation
|
||||
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
@@ -57,6 +56,7 @@ from custom_components.llama_conversation.const import (
|
||||
DEFAULT_LLAMACPP_THREAD_COUNT,
|
||||
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
|
||||
DOMAIN,
|
||||
CONF_RESPONSE_JSON_SCHEMA,
|
||||
)
|
||||
from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult
|
||||
|
||||
@@ -64,10 +64,15 @@ from custom_components.llama_conversation.entity import LocalLLMClient, TextGene
|
||||
from typing import TYPE_CHECKING
|
||||
from types import ModuleType
|
||||
if TYPE_CHECKING:
|
||||
from llama_cpp import Llama as LlamaType, LlamaGrammar as LlamaGrammarType
|
||||
from llama_cpp import (
|
||||
Llama as LlamaType,
|
||||
LlamaGrammar as LlamaGrammarType,
|
||||
ChatCompletionRequestResponseFormat
|
||||
)
|
||||
else:
|
||||
LlamaType = Any
|
||||
LlamaGrammarType = Any
|
||||
ChatCompletionRequestResponseFormat = Any
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -283,8 +288,6 @@ class LlamaCppClient(LocalLLMClient):
|
||||
# Sort the items based on the sort_key function
|
||||
sorted_items = sorted(list(entity_order.items()), key=sort_key)
|
||||
|
||||
_LOGGER.debug(f"sorted_items: {sorted_items}")
|
||||
|
||||
sorted_entities: dict[str, dict[str, str]] = {}
|
||||
for item_name, _ in sorted_items:
|
||||
sorted_entities[item_name] = entities[item_name]
|
||||
@@ -297,7 +300,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
|
||||
entity_ids = [
|
||||
state.entity_id for state in self.hass.states.async_all() \
|
||||
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id)
|
||||
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
|
||||
]
|
||||
|
||||
_LOGGER.debug(f"watching entities: {entity_ids}")
|
||||
@@ -434,13 +437,21 @@ class LlamaCppClient(LocalLLMClient):
|
||||
|
||||
_LOGGER.debug(f"Options: {entity_options}")
|
||||
|
||||
messages = get_oai_formatted_messages(conversation, user_content_as_list=True)
|
||||
messages = get_oai_formatted_messages(conversation)
|
||||
tools = None
|
||||
if llm_api:
|
||||
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
|
||||
|
||||
_LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...")
|
||||
|
||||
response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
|
||||
response_format: Optional[ChatCompletionRequestResponseFormat] = None
|
||||
if response_json_schema:
|
||||
response_format = {
|
||||
"type": "json_object",
|
||||
"schema": response_json_schema,
|
||||
}
|
||||
|
||||
chat_completion = self.models[model_name].create_chat_completion(
|
||||
messages,
|
||||
tools=tools,
|
||||
@@ -452,6 +463,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
max_tokens=max_tokens,
|
||||
grammar=grammar,
|
||||
stream=True,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:
|
||||
@@ -464,5 +476,5 @@ class LlamaCppClient(LocalLLMClient):
|
||||
tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
|
||||
yield content, tool_calls
|
||||
|
||||
return self._async_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())
|
||||
return self._async_stream_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())
|
||||
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
"""Defines the ollama compatible agent"""
|
||||
"""Defines the Ollama compatible agent backed by the official python client."""
|
||||
from __future__ import annotations
|
||||
from warnings import deprecated
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Tuple, Dict, List, Any, AsyncGenerator
|
||||
import ssl
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
import certifi
|
||||
import httpx
|
||||
from ollama import AsyncClient, ChatResponse, ResponseError
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.components import conversation as conversation
|
||||
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
|
||||
@@ -23,119 +24,167 @@ from custom_components.llama_conversation.const import (
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_TYPICAL_P,
|
||||
CONF_MIN_P,
|
||||
CONF_ENABLE_THINK_MODE,
|
||||
CONF_REQUEST_TIMEOUT,
|
||||
CONF_OPENAI_API_KEY,
|
||||
CONF_GENERIC_OPENAI_PATH,
|
||||
CONF_OLLAMA_KEEP_ALIVE_MIN,
|
||||
CONF_OLLAMA_JSON_MODE,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_ENABLE_LEGACY_TOOL_CALLING,
|
||||
CONF_RESPONSE_JSON_SCHEMA,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_K,
|
||||
DEFAULT_TOP_P,
|
||||
DEFAULT_TYPICAL_P,
|
||||
DEFAULT_MIN_P,
|
||||
DEFAULT_ENABLE_THINK_MODE,
|
||||
DEFAULT_REQUEST_TIMEOUT,
|
||||
DEFAULT_GENERIC_OPENAI_PATH,
|
||||
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
|
||||
DEFAULT_OLLAMA_JSON_MODE,
|
||||
DEFAULT_CONTEXT_LENGTH,
|
||||
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
|
||||
)
|
||||
|
||||
from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@deprecated("Use the built-in Ollama integration instead")
|
||||
|
||||
def _normalize_path(path: str | None) -> str:
|
||||
if not path:
|
||||
return ""
|
||||
trimmed = str(path).strip("/")
|
||||
return f"/{trimmed}" if trimmed else ""
|
||||
|
||||
|
||||
def _build_default_ssl_context() -> ssl.SSLContext:
|
||||
context = ssl.create_default_context()
|
||||
try:
|
||||
context.load_verify_locations(certifi.where())
|
||||
except OSError as err:
|
||||
_LOGGER.debug("Failed to load certifi bundle for Ollama client: %s", err)
|
||||
return context
|
||||
|
||||
class OllamaAPIClient(LocalLLMClient):
|
||||
api_host: str
|
||||
api_key: Optional[str]
|
||||
|
||||
def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
|
||||
super().__init__(hass, client_options)
|
||||
base_path = _normalize_path(client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
|
||||
self.api_host = format_url(
|
||||
hostname=client_options[CONF_HOST],
|
||||
port=client_options[CONF_PORT],
|
||||
ssl=client_options[CONF_SSL],
|
||||
path=client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
|
||||
path=base_path,
|
||||
)
|
||||
self.api_key = client_options.get(CONF_OPENAI_API_KEY) or None
|
||||
self._headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else None
|
||||
self._ssl_context = _build_default_ssl_context() if client_options.get(CONF_SSL) else None
|
||||
|
||||
self.api_key = client_options.get(CONF_OPENAI_API_KEY, "")
|
||||
def _build_client(self, *, timeout: float | int | httpx.Timeout | None = None) -> AsyncClient:
|
||||
timeout_config: httpx.Timeout | float | None = timeout
|
||||
if isinstance(timeout, (int, float)):
|
||||
timeout_config = httpx.Timeout(timeout)
|
||||
|
||||
return AsyncClient(
|
||||
host=self.api_host,
|
||||
headers=self._headers,
|
||||
timeout=timeout_config,
|
||||
verify=self._ssl_context,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_name(client_options: dict[str, Any]):
|
||||
host = client_options[CONF_HOST]
|
||||
port = client_options[CONF_PORT]
|
||||
ssl = client_options[CONF_SSL]
|
||||
path = "/" + client_options[CONF_GENERIC_OPENAI_PATH]
|
||||
path = _normalize_path(client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
|
||||
return f"Ollama at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
|
||||
|
||||
@staticmethod
|
||||
async def async_validate_connection(hass: HomeAssistant, user_input: Dict[str, Any]) -> str | None:
|
||||
headers = {}
|
||||
api_key = user_input.get(CONF_OPENAI_API_KEY)
|
||||
api_base_path = user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
base_path = _normalize_path(user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
|
||||
timeout_config: httpx.Timeout | float | None = httpx.Timeout(5)
|
||||
|
||||
verify_context = None
|
||||
if user_input.get(CONF_SSL):
|
||||
verify_context = await hass.async_add_executor_job(_build_default_ssl_context)
|
||||
|
||||
client = AsyncClient(
|
||||
host=format_url(
|
||||
hostname=user_input[CONF_HOST],
|
||||
port=user_input[CONF_PORT],
|
||||
ssl=user_input[CONF_SSL],
|
||||
path=base_path,
|
||||
),
|
||||
headers={"Authorization": f"Bearer {api_key}"} if api_key else None,
|
||||
timeout=timeout_config,
|
||||
verify=verify_context,
|
||||
)
|
||||
|
||||
try:
|
||||
session = async_get_clientsession(hass)
|
||||
async with session.get(
|
||||
format_url(
|
||||
hostname=user_input[CONF_HOST],
|
||||
port=user_input[CONF_PORT],
|
||||
ssl=user_input[CONF_SSL],
|
||||
path=f"/{api_base_path}/api/tags"
|
||||
),
|
||||
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
|
||||
headers=headers
|
||||
) as response:
|
||||
if response.ok:
|
||||
return None
|
||||
else:
|
||||
return f"HTTP Status {response.status}"
|
||||
except Exception as ex:
|
||||
return str(ex)
|
||||
|
||||
await client.list()
|
||||
except httpx.TimeoutException:
|
||||
return "Connection timed out"
|
||||
except ResponseError as err:
|
||||
return f"HTTP Status {err.status_code}: {err.error}"
|
||||
except ConnectionError as err:
|
||||
return str(err)
|
||||
|
||||
return None
|
||||
|
||||
async def async_get_available_models(self) -> List[str]:
|
||||
headers = {}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
client = self._build_client(timeout=5)
|
||||
try:
|
||||
response = await client.list()
|
||||
except httpx.TimeoutException as err:
|
||||
raise HomeAssistantError("Timed out while fetching models from the Ollama server") from err
|
||||
except (ResponseError, ConnectionError) as err:
|
||||
raise HomeAssistantError(f"Failed to fetch models from the Ollama server: {err}") from err
|
||||
|
||||
session = async_get_clientsession(self.hass)
|
||||
async with session.get(
|
||||
f"{self.api_host}/api/tags",
|
||||
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
models_result = await response.json()
|
||||
models: List[str] = []
|
||||
for model in getattr(response, "models", []) or []:
|
||||
candidate = getattr(model, "name", None) or getattr(model, "model", None)
|
||||
if candidate:
|
||||
models.append(candidate)
|
||||
|
||||
return [x["name"] for x in models_result["models"]]
|
||||
return models
|
||||
|
||||
def _extract_response(self, response_json: Dict) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
|
||||
# TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
|
||||
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
||||
# max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
||||
# if response_json["prompt_eval_count"] + max_tokens > context_len:
|
||||
# self._warn_context_size()
|
||||
def _extract_response(self, response_chunk: ChatResponse) -> Tuple[Optional[str], Optional[List[dict]]]:
|
||||
content = response_chunk.message.content
|
||||
raw_tool_calls = response_chunk.message.tool_calls
|
||||
|
||||
if "response" in response_json:
|
||||
response = response_json["response"]
|
||||
tool_calls = None
|
||||
stop_reason = None
|
||||
if response_json["done"] not in ["true", True]:
|
||||
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
|
||||
if raw_tool_calls:
|
||||
# return openai formatted tool calls
|
||||
tool_calls = [{
|
||||
"function": {
|
||||
"name": call.function.name,
|
||||
"arguments": call.function.arguments,
|
||||
}
|
||||
} for call in raw_tool_calls]
|
||||
else:
|
||||
response = response_json["message"]["content"]
|
||||
raw_tool_calls = response_json["message"].get("tool_calls")
|
||||
tool_calls = [ llm.ToolInput(tool_name=x["function"]["name"], tool_args=x["function"]["arguments"]) for x in raw_tool_calls] if raw_tool_calls else None
|
||||
stop_reason = response_json.get("done_reason")
|
||||
tool_calls = None
|
||||
|
||||
# _LOGGER.debug(f"{response=} {tool_calls=}")
|
||||
return content, tool_calls
|
||||
|
||||
return response, tool_calls
|
||||
@staticmethod
|
||||
def _format_keep_alive(value: Any) -> Any:
|
||||
as_text = str(value).strip()
|
||||
return 0 if as_text in {"0", "0.0"} else f"{as_text}m"
|
||||
|
||||
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, entity_options: Dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
|
||||
def _generate_stream(
|
||||
self,
|
||||
conversation: List[conversation.Content],
|
||||
llm_api: llm.APIInstance | None,
|
||||
agent_id: str,
|
||||
entity_options: Dict[str, Any],
|
||||
) -> AsyncGenerator[TextGenerationResult, None]:
|
||||
model_name = entity_options.get(CONF_CHAT_MODEL, "")
|
||||
context_length = entity_options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
||||
max_tokens = entity_options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
||||
@@ -145,58 +194,48 @@ class OllamaAPIClient(LocalLLMClient):
|
||||
typical_p = entity_options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
|
||||
timeout = entity_options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
|
||||
keep_alive = entity_options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
|
||||
legacy_tool_calling = entity_options.get(CONF_ENABLE_LEGACY_TOOL_CALLING, DEFAULT_ENABLE_LEGACY_TOOL_CALLING)
|
||||
think_mode = entity_options.get(CONF_ENABLE_THINK_MODE, DEFAULT_ENABLE_THINK_MODE)
|
||||
json_mode = entity_options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
|
||||
|
||||
request_params = {
|
||||
"model": model_name,
|
||||
"stream": True,
|
||||
"keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
|
||||
"options": {
|
||||
"num_ctx": context_length,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"typical_p": typical_p,
|
||||
"temperature": temperature,
|
||||
"num_predict": max_tokens,
|
||||
},
|
||||
options = {
|
||||
"num_ctx": context_length,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"typical_p": typical_p,
|
||||
"temperature": temperature,
|
||||
"num_predict": max_tokens,
|
||||
"min_p": entity_options.get(CONF_MIN_P, DEFAULT_MIN_P),
|
||||
}
|
||||
|
||||
if json_mode:
|
||||
request_params["format"] = "json"
|
||||
messages = get_oai_formatted_messages(conversation, tool_args_to_str=False)
|
||||
tools = None
|
||||
if llm_api and not legacy_tool_calling:
|
||||
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
|
||||
keep_alive_payload = self._format_keep_alive(keep_alive)
|
||||
|
||||
if llm_api:
|
||||
request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
|
||||
|
||||
endpoint = "/api/chat"
|
||||
request_params["messages"] = get_oai_formatted_messages(conversation, tool_args_to_str=False)
|
||||
|
||||
headers = {}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
session = async_get_clientsession(self.hass)
|
||||
|
||||
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
|
||||
response = None
|
||||
chunk = None
|
||||
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
|
||||
client = self._build_client(timeout=timeout)
|
||||
try:
|
||||
async with session.post(
|
||||
f"{self.api_host}{endpoint}",
|
||||
json=request_params,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
while True:
|
||||
chunk = await response.content.readline()
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
yield self._extract_response(json.loads(chunk))
|
||||
except asyncio.TimeoutError as err:
|
||||
raise HomeAssistantError("The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.") from err
|
||||
except aiohttp.ClientError as err:
|
||||
format_option = entity_options.get(CONF_RESPONSE_JSON_SCHEMA, "json" if json_mode else None)
|
||||
stream = await client.chat(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
think=think_mode,
|
||||
format=format_option,
|
||||
options=options,
|
||||
keep_alive=keep_alive_payload,
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
yield self._extract_response(chunk)
|
||||
except httpx.TimeoutException as err:
|
||||
raise HomeAssistantError(
|
||||
"The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
|
||||
) from err
|
||||
except (ResponseError, ConnectionError) as err:
|
||||
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
|
||||
|
||||
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
|
||||
return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
|
||||
|
||||
@@ -10,6 +10,7 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime
|
||||
from homeassistant.components import conversation, ai_task
|
||||
from homeassistant.data_entry_flow import (
|
||||
AbortFlow,
|
||||
)
|
||||
@@ -46,6 +47,11 @@ from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_PROMPT,
|
||||
DEFAULT_AI_TASK_PROMPT,
|
||||
CONF_AI_TASK_RETRIES,
|
||||
DEFAULT_AI_TASK_RETRIES,
|
||||
CONF_AI_TASK_EXTRACTION_METHOD,
|
||||
DEFAULT_AI_TASK_EXTRACTION_METHOD,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
@@ -150,7 +156,7 @@ from .const import (
|
||||
DEFAULT_OPTIONS,
|
||||
option_overrides,
|
||||
RECOMMENDED_CHAT_MODELS,
|
||||
EMBEDDED_LLAMA_CPP_PYTHON_VERSION
|
||||
EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
|
||||
)
|
||||
|
||||
from . import HomeLLMAPI, LocalLLMConfigEntry, LocalLLMClient, BACKEND_TO_CLS
|
||||
@@ -225,7 +231,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Local LLM Conversation."""
|
||||
|
||||
VERSION = 3
|
||||
MINOR_VERSION = 1
|
||||
MINOR_VERSION = 2
|
||||
|
||||
install_wheel_task = None
|
||||
install_wheel_error = None
|
||||
@@ -399,8 +405,8 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
|
||||
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||
"""Return subentries supported by this integration."""
|
||||
return {
|
||||
"conversation": LocalLLMSubentryFlowHandler,
|
||||
# "ai_task_data": LocalLLMSubentryFlowHandler,
|
||||
conversation.DOMAIN: LocalLLMSubentryFlowHandler,
|
||||
ai_task.DOMAIN: LocalLLMSubentryFlowHandler,
|
||||
}
|
||||
|
||||
|
||||
@@ -583,40 +589,13 @@ def local_llama_config_option_schema(
|
||||
backend_type: str,
|
||||
subentry_type: str,
|
||||
) -> dict:
|
||||
|
||||
default_prompt = build_prompt_template(language, DEFAULT_PROMPT)
|
||||
|
||||
|
||||
result: dict = {
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_PROMPT, default_prompt)},
|
||||
default=options.get(CONF_PROMPT, default_prompt),
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_TEMPERATURE,
|
||||
description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)},
|
||||
default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
|
||||
): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)),
|
||||
vol.Required(
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
|
||||
default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
): BooleanSelector(BooleanSelectorConfig()),
|
||||
vol.Required(
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
||||
description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
|
||||
default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
): str,
|
||||
vol.Required(
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
||||
description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
|
||||
default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
||||
): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
|
||||
vol.Required(
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
|
||||
default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
): TextSelector(TextSelectorConfig(multiple=True)),
|
||||
vol.Required(
|
||||
CONF_THINKING_PREFIX,
|
||||
description={"suggested_value": options.get(CONF_THINKING_PREFIX)},
|
||||
@@ -644,9 +623,114 @@ def local_llama_config_option_schema(
|
||||
): bool,
|
||||
}
|
||||
|
||||
if backend_type == BACKEND_TYPE_LLAMA_CPP:
|
||||
if subentry_type == ai_task.DOMAIN:
|
||||
result.update({
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT)},
|
||||
default=options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT),
|
||||
): TemplateSelector(),
|
||||
vol.Required(
|
||||
CONF_AI_TASK_EXTRACTION_METHOD,
|
||||
description={"suggested_value": options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)},
|
||||
default=options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD),
|
||||
): SelectSelector(SelectSelectorConfig(
|
||||
options=[
|
||||
SelectOptionDict(value="none", label="None"),
|
||||
SelectOptionDict(value="structure", label="Structured output"),
|
||||
SelectOptionDict(value="tool", label="Tool call"),
|
||||
],
|
||||
mode=SelectSelectorMode.DROPDOWN,
|
||||
)),
|
||||
vol.Required(
|
||||
CONF_AI_TASK_RETRIES,
|
||||
description={"suggested_value": options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES)},
|
||||
default=options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES),
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=5, step=1, mode=NumberSelectorMode.BOX)),
|
||||
})
|
||||
elif subentry_type == conversation.DOMAIN:
|
||||
default_prompt = build_prompt_template(language, DEFAULT_PROMPT)
|
||||
apis: list[SelectOptionDict] = [
|
||||
SelectOptionDict(
|
||||
label=api.name,
|
||||
value=api.id,
|
||||
)
|
||||
for api in llm.async_get_apis(hass)
|
||||
]
|
||||
result.update({
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_PROMPT, default_prompt)},
|
||||
default=options.get(CONF_PROMPT, default_prompt),
|
||||
): TemplateSelector(),
|
||||
vol.Required(
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
|
||||
default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
): BooleanSelector(BooleanSelectorConfig()),
|
||||
vol.Required(
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
||||
description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
|
||||
default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
): str,
|
||||
vol.Required(
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
||||
description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
|
||||
default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
||||
): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
|
||||
vol.Required(
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
|
||||
default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
): TextSelector(TextSelectorConfig(multiple=True)),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||
default=None,
|
||||
): SelectSelector(SelectSelectorConfig(options=apis, multiple=True)),
|
||||
vol.Optional(
|
||||
CONF_REFRESH_SYSTEM_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)},
|
||||
default=options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT),
|
||||
): BooleanSelector(BooleanSelectorConfig()),
|
||||
vol.Optional(
|
||||
CONF_REMEMBER_CONVERSATION,
|
||||
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)},
|
||||
default=options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION),
|
||||
): BooleanSelector(BooleanSelectorConfig()),
|
||||
vol.Optional(
|
||||
CONF_REMEMBER_NUM_INTERACTIONS,
|
||||
description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)},
|
||||
default=options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS),
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=100, mode=NumberSelectorMode.BOX)),
|
||||
vol.Optional(
|
||||
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
|
||||
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION)},
|
||||
default=options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION),
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1440, mode=NumberSelectorMode.BOX)),
|
||||
vol.Required(
|
||||
CONF_MAX_TOOL_CALL_ITERATIONS,
|
||||
description={"suggested_value": options.get(CONF_MAX_TOOL_CALL_ITERATIONS)},
|
||||
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
|
||||
): int,
|
||||
})
|
||||
|
||||
if backend_type == BACKEND_TYPE_LLAMA_CPP:
|
||||
if subentry_type == conversation.DOMAIN:
|
||||
result.update({
|
||||
vol.Required(
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
description={"suggested_value": options.get(CONF_PROMPT_CACHING_ENABLED)},
|
||||
default=DEFAULT_PROMPT_CACHING_ENABLED,
|
||||
): BooleanSelector(BooleanSelectorConfig()),
|
||||
vol.Required(
|
||||
CONF_PROMPT_CACHING_INTERVAL,
|
||||
description={"suggested_value": options.get(CONF_PROMPT_CACHING_INTERVAL)},
|
||||
default=DEFAULT_PROMPT_CACHING_INTERVAL,
|
||||
): NumberSelector(NumberSelectorConfig(min=1, max=60, step=1)),
|
||||
})
|
||||
result.update({
|
||||
vol.Required(
|
||||
CONF_MAX_TOKENS,
|
||||
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
||||
default=DEFAULT_MAX_TOKENS,
|
||||
@@ -671,16 +755,6 @@ def local_llama_config_option_schema(
|
||||
description={"suggested_value": options.get(CONF_TYPICAL_P)},
|
||||
default=DEFAULT_TYPICAL_P,
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||
vol.Required(
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
description={"suggested_value": options.get(CONF_PROMPT_CACHING_ENABLED)},
|
||||
default=DEFAULT_PROMPT_CACHING_ENABLED,
|
||||
): BooleanSelector(BooleanSelectorConfig()),
|
||||
vol.Required(
|
||||
CONF_PROMPT_CACHING_INTERVAL,
|
||||
description={"suggested_value": options.get(CONF_PROMPT_CACHING_INTERVAL)},
|
||||
default=DEFAULT_PROMPT_CACHING_INTERVAL,
|
||||
): NumberSelector(NumberSelectorConfig(min=1, max=60, step=1)),
|
||||
# TODO: add rope_scaling_type
|
||||
vol.Required(
|
||||
CONF_CONTEXT_LENGTH,
|
||||
@@ -879,60 +953,13 @@ def local_llama_config_option_schema(
|
||||
): NumberSelector(NumberSelectorConfig(min=-1, max=1440, step=1, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)),
|
||||
})
|
||||
|
||||
if subentry_type == "conversation":
|
||||
apis: list[SelectOptionDict] = [
|
||||
SelectOptionDict(
|
||||
label="No control",
|
||||
value="none",
|
||||
)
|
||||
]
|
||||
apis.extend(
|
||||
SelectOptionDict(
|
||||
label=api.name,
|
||||
value=api.id,
|
||||
)
|
||||
for api in llm.async_get_apis(hass)
|
||||
)
|
||||
result.update({
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||
default="none",
|
||||
): SelectSelector(SelectSelectorConfig(options=apis)),
|
||||
vol.Optional(
|
||||
CONF_REFRESH_SYSTEM_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)},
|
||||
default=options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT),
|
||||
): BooleanSelector(BooleanSelectorConfig()),
|
||||
vol.Optional(
|
||||
CONF_REMEMBER_CONVERSATION,
|
||||
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)},
|
||||
default=options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION),
|
||||
): BooleanSelector(BooleanSelectorConfig()),
|
||||
vol.Optional(
|
||||
CONF_REMEMBER_NUM_INTERACTIONS,
|
||||
description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)},
|
||||
default=options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS),
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=100, mode=NumberSelectorMode.BOX)),
|
||||
vol.Optional(
|
||||
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
|
||||
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION)},
|
||||
default=options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION),
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1440, mode=NumberSelectorMode.BOX)),
|
||||
vol.Required(
|
||||
CONF_MAX_TOOL_CALL_ITERATIONS,
|
||||
description={"suggested_value": options.get(CONF_MAX_TOOL_CALL_ITERATIONS)},
|
||||
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
|
||||
): int,
|
||||
})
|
||||
elif subentry_type == "ai_task_data":
|
||||
pass # no additional options for ai_task_data for now
|
||||
|
||||
# sort the options
|
||||
global_order = [
|
||||
# general
|
||||
CONF_LLM_HASS_API,
|
||||
CONF_PROMPT,
|
||||
CONF_AI_TASK_EXTRACTION_METHOD,
|
||||
CONF_AI_TASK_RETRIES,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_MAX_TOKENS,
|
||||
# sampling parameters
|
||||
@@ -1122,8 +1149,16 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
description_placeholders = {}
|
||||
entry = self._get_entry()
|
||||
backend_type = entry.data[CONF_BACKEND_TYPE]
|
||||
is_ai_task = self._subentry_type == ai_task.DOMAIN
|
||||
|
||||
if CONF_PROMPT not in self.model_config:
|
||||
if is_ai_task:
|
||||
if CONF_PROMPT not in self.model_config:
|
||||
self.model_config[CONF_PROMPT] = DEFAULT_AI_TASK_PROMPT
|
||||
if CONF_AI_TASK_RETRIES not in self.model_config:
|
||||
self.model_config[CONF_AI_TASK_RETRIES] = DEFAULT_AI_TASK_RETRIES
|
||||
if CONF_AI_TASK_EXTRACTION_METHOD not in self.model_config:
|
||||
self.model_config[CONF_AI_TASK_EXTRACTION_METHOD] = DEFAULT_AI_TASK_EXTRACTION_METHOD
|
||||
elif CONF_PROMPT not in self.model_config:
|
||||
# determine selected language from model config or parent options
|
||||
selected_language = self.model_config.get(
|
||||
CONF_SELECTED_LANGUAGE, entry.options.get(CONF_SELECTED_LANGUAGE, "en")
|
||||
@@ -1156,20 +1191,21 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
)
|
||||
|
||||
if user_input:
|
||||
if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
|
||||
errors["base"] = "sys_refresh_caching_enabled"
|
||||
if not is_ai_task:
|
||||
if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
|
||||
errors["base"] = "sys_refresh_caching_enabled"
|
||||
|
||||
if user_input.get(CONF_USE_GBNF_GRAMMAR):
|
||||
filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
|
||||
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
|
||||
errors["base"] = "missing_gbnf_file"
|
||||
description_placeholders["filename"] = filename
|
||||
if user_input.get(CONF_USE_GBNF_GRAMMAR):
|
||||
filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
|
||||
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
|
||||
errors["base"] = "missing_gbnf_file"
|
||||
description_placeholders["filename"] = filename
|
||||
|
||||
if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
|
||||
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
|
||||
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
|
||||
errors["base"] = "missing_icl_file"
|
||||
description_placeholders["filename"] = filename
|
||||
if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
|
||||
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
|
||||
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
|
||||
errors["base"] = "missing_icl_file"
|
||||
description_placeholders["filename"] = filename
|
||||
|
||||
# --- Normalize numeric fields to ints to avoid slice/type errors later ---
|
||||
for key in (
|
||||
@@ -1178,6 +1214,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_REQUEST_TIMEOUT,
|
||||
CONF_AI_TASK_RETRIES,
|
||||
):
|
||||
if key in user_input:
|
||||
user_input[key] = _coerce_int(user_input[key], user_input.get(key) or 0)
|
||||
@@ -1187,10 +1224,6 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
# validate input
|
||||
schema(user_input)
|
||||
self.model_config.update(user_input)
|
||||
|
||||
# clear LLM API if 'none' selected
|
||||
if self.model_config.get(CONF_LLM_HASS_API) == "none":
|
||||
self.model_config.pop(CONF_LLM_HASS_API, None)
|
||||
|
||||
return await self.async_step_finish()
|
||||
except Exception:
|
||||
|
||||
@@ -8,6 +8,11 @@ SERVICE_TOOL_NAME = "HassCallService"
|
||||
SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"]
|
||||
SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"]
|
||||
CONF_PROMPT = "prompt"
|
||||
DEFAULT_AI_TASK_PROMPT = "You are a task-specific assistant. Follow the task instructions and return the requested data."
|
||||
CONF_AI_TASK_RETRIES = "ai_task_retries"
|
||||
DEFAULT_AI_TASK_RETRIES = 1
|
||||
CONF_AI_TASK_EXTRACTION_METHOD = "ai_task_extraction_method"
|
||||
DEFAULT_AI_TASK_EXTRACTION_METHOD = "structure"
|
||||
PERSONA_PROMPTS = {
|
||||
"en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.",
|
||||
"de": "Du bist \u201eAl\u201c, ein hilfreicher KI-Assistent, der die Ger\u00e4te in einem Haus steuert. F\u00fchren Sie die folgende Aufgabe gem\u00e4\u00df den Anweisungen durch oder beantworten Sie die folgende Frage nur mit den bereitgestellten Informationen.",
|
||||
@@ -104,6 +109,8 @@ CONF_TEMPERATURE = "temperature"
|
||||
DEFAULT_TEMPERATURE = 0.1
|
||||
CONF_REQUEST_TIMEOUT = "request_timeout"
|
||||
DEFAULT_REQUEST_TIMEOUT = 90
|
||||
CONF_ENABLE_THINK_MODE = "enable_think_mode"
|
||||
DEFAULT_ENABLE_THINK_MODE = False
|
||||
CONF_BACKEND_TYPE = "model_backend"
|
||||
BACKEND_TYPE_LLAMA_HF_OLD = "llama_cpp_hf"
|
||||
BACKEND_TYPE_LLAMA_EXISTING_OLD = "llama_cpp_existing"
|
||||
@@ -185,7 +192,8 @@ DEFAULT_GENERIC_OPENAI_PATH = "v1"
|
||||
CONF_GENERIC_OPENAI_VALIDATE_MODEL = "openai_validate_model"
|
||||
DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL = True
|
||||
CONF_CONTEXT_LENGTH = "context_length"
|
||||
DEFAULT_CONTEXT_LENGTH = 2048
|
||||
DEFAULT_CONTEXT_LENGTH = 8192
|
||||
CONF_RESPONSE_JSON_SCHEMA = "response_json_schema"
|
||||
CONF_LLAMACPP_BATCH_SIZE = "batch_size"
|
||||
DEFAULT_LLAMACPP_BATCH_SIZE = 512
|
||||
CONF_LLAMACPP_THREAD_COUNT = "n_threads"
|
||||
|
||||
@@ -17,6 +17,7 @@ from custom_components.llama_conversation.utils import MalformedToolCallExceptio
|
||||
|
||||
from .entity import LocalLLMEntity, LocalLLMClient, LocalLLMConfigEntry
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_PROMPT,
|
||||
CONF_REFRESH_SYSTEM_PROMPT,
|
||||
CONF_REMEMBER_CONVERSATION,
|
||||
@@ -39,6 +40,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry, asy
|
||||
if subentry.subentry_type != conversation.DOMAIN:
|
||||
continue
|
||||
|
||||
if CONF_CHAT_MODEL not in subentry.data:
|
||||
_LOGGER.warning("Conversation subentry %s missing required config key %s, You must delete the model and re-create it.", subentry.subentry_id, CONF_CHAT_MODEL)
|
||||
continue
|
||||
|
||||
# create one agent entity per conversation subentry
|
||||
agent_entity = LocalLLMAgent(hass, entry, subentry, entry.runtime_data)
|
||||
|
||||
|
||||
@@ -5,15 +5,16 @@ import csv
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from typing import Literal, Any, List, Dict, Optional, Tuple, AsyncIterator, Generator, AsyncGenerator
|
||||
import re
|
||||
from typing import Literal, Any, List, Dict, Optional, Sequence, Tuple, AsyncIterator, Generator, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||
from homeassistant.const import MATCH_ALL, CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import template, entity_registry as er, llm, \
|
||||
area_registry as ar, device_registry as dr, entity
|
||||
from homeassistant.util import color
|
||||
@@ -41,8 +42,6 @@ from .const import (
|
||||
DEFAULT_TOOL_CALL_PREFIX,
|
||||
DEFAULT_TOOL_CALL_SUFFIX,
|
||||
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
|
||||
HOME_LLM_API_ID,
|
||||
SERVICE_TOOL_NAME,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -184,28 +183,25 @@ class LocalLLMClient:
|
||||
_LOGGER.debug("Received chunk: %s", input_chunk)
|
||||
|
||||
tool_calls = input_chunk.tool_calls
|
||||
# fix tool calls for the service tool
|
||||
if tool_calls and chat_log.llm_api and chat_log.llm_api.api.id == HOME_LLM_API_ID:
|
||||
tool_calls = [
|
||||
llm.ToolInput(
|
||||
tool_name=SERVICE_TOOL_NAME,
|
||||
tool_args={**tc.tool_args, "service": tc.tool_name}
|
||||
) for tc in tool_calls
|
||||
]
|
||||
if tool_calls and not chat_log.llm_api:
|
||||
raise HomeAssistantError("Model attempted to call a tool but no LLM API was provided")
|
||||
|
||||
yield conversation.AssistantContentDeltaDict(
|
||||
content=input_chunk.response,
|
||||
tool_calls=tool_calls
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
|
||||
return chat_log.async_add_delta_content_stream(agent_id, stream=async_iterator())
|
||||
|
||||
async def _async_parse_completion(
|
||||
self, llm_api: llm.APIInstance | None,
|
||||
async def _async_stream_parse_completion(
|
||||
self,
|
||||
llm_api: llm.APIInstance | None,
|
||||
agent_id: str,
|
||||
entity_options: Dict[str, Any],
|
||||
next_token: Optional[Generator[Tuple[Optional[str], Optional[List]]]] = None,
|
||||
anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[List]]]] = None,
|
||||
next_token: Optional[Generator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
|
||||
anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[Sequence[str | dict]]]]] = None,
|
||||
) -> AsyncGenerator[TextGenerationResult, None]:
|
||||
"""Parse streaming completion with tool calls from the backend. Accepts either a sync or async token generator."""
|
||||
think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
|
||||
think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
|
||||
tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
|
||||
@@ -236,7 +232,7 @@ class LocalLLMClient:
|
||||
cur_match_length = 0
|
||||
async for chunk in token_generator:
|
||||
# _LOGGER.debug(f"Handling chunk: {chunk} {in_thinking=} {in_tool_call=} {last_5_tokens=}")
|
||||
tool_calls: Optional[List[str | llm.ToolInput | dict]]
|
||||
tool_calls: Optional[List[str | dict]]
|
||||
content, tool_calls = chunk
|
||||
|
||||
if not tool_calls:
|
||||
@@ -289,31 +285,73 @@ class LocalLLMClient:
|
||||
_LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
|
||||
else:
|
||||
for raw_tool_call in tool_calls:
|
||||
if isinstance(raw_tool_call, llm.ToolInput):
|
||||
parsed_tool_calls.append(raw_tool_call)
|
||||
if isinstance(raw_tool_call, str):
|
||||
tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
|
||||
else:
|
||||
if isinstance(raw_tool_call, str):
|
||||
tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, agent_id)
|
||||
else:
|
||||
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, agent_id)
|
||||
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
|
||||
|
||||
if tool_call:
|
||||
_LOGGER.debug("Tool call parsed: %s", tool_call)
|
||||
parsed_tool_calls.append(tool_call)
|
||||
if to_say:
|
||||
result.response = to_say
|
||||
if tool_call:
|
||||
_LOGGER.debug("Tool call parsed: %s", tool_call)
|
||||
parsed_tool_calls.append(tool_call)
|
||||
if to_say:
|
||||
result.response = to_say
|
||||
|
||||
if len(parsed_tool_calls) > 0:
|
||||
result.tool_calls = parsed_tool_calls
|
||||
|
||||
if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
|
||||
yield result
|
||||
|
||||
async def _async_parse_completion(
|
||||
self,
|
||||
llm_api: llm.APIInstance | None,
|
||||
agent_id: str,
|
||||
entity_options: Dict[str, Any],
|
||||
completion: str | dict) -> TextGenerationResult:
|
||||
"""Parse completion with tool calls from the backend."""
|
||||
think_prefix = entity_options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
|
||||
think_suffix = entity_options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
|
||||
think_regex = re.compile(re.escape(think_prefix) + "(.*?)" + re.escape(think_suffix), re.DOTALL)
|
||||
tool_prefix = entity_options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
|
||||
tool_suffix = entity_options.get(CONF_TOOL_CALL_SUFFIX, DEFAULT_TOOL_CALL_SUFFIX)
|
||||
tool_regex = re.compile(re.escape(tool_prefix) + "(.*?)" + re.escape(tool_suffix), re.DOTALL)
|
||||
|
||||
if isinstance(completion, dict):
|
||||
completion = str(completion.get("response", ""))
|
||||
|
||||
# Remove thinking blocks, and extract tool calls
|
||||
tool_calls = tool_regex.findall(completion)
|
||||
completion = think_regex.sub("", completion)
|
||||
completion = tool_regex.sub("", completion)
|
||||
|
||||
to_say = ""
|
||||
parsed_tool_calls: list[llm.ToolInput] = []
|
||||
if len(tool_calls) and not llm_api:
|
||||
_LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
|
||||
else:
|
||||
for raw_tool_call in tool_calls:
|
||||
if isinstance(raw_tool_call, llm.ToolInput):
|
||||
parsed_tool_calls.append(raw_tool_call)
|
||||
else:
|
||||
if isinstance(raw_tool_call, str):
|
||||
tool_call, to_say = parse_raw_tool_call(raw_tool_call, agent_id)
|
||||
else:
|
||||
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], agent_id)
|
||||
|
||||
if tool_call:
|
||||
_LOGGER.debug("Tool call parsed: %s", tool_call)
|
||||
parsed_tool_calls.append(tool_call)
|
||||
|
||||
return TextGenerationResult(
|
||||
response=completion + (to_say or ""),
|
||||
tool_calls=parsed_tool_calls,
|
||||
)
|
||||
|
||||
def _async_get_all_exposed_domains(self) -> list[str]:
|
||||
"""Gather all exposed domains"""
|
||||
domains = set()
|
||||
for state in self.hass.states.async_all():
|
||||
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
|
||||
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id):
|
||||
domains.add(state.domain)
|
||||
|
||||
return list(domains)
|
||||
@@ -326,7 +364,7 @@ class LocalLLMClient:
|
||||
area_registry = ar.async_get(self.hass)
|
||||
|
||||
for state in self.hass.states.async_all():
|
||||
if not async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
|
||||
if not async_should_expose(self.hass, conversation.DOMAIN, state.entity_id):
|
||||
continue
|
||||
|
||||
entity = entity_registry.async_get(state.entity_id)
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
{
|
||||
"domain": "llama_conversation",
|
||||
"name": "Local LLMs",
|
||||
"version": "0.4.4",
|
||||
"version": "0.4.5",
|
||||
"codeowners": ["@acon96"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["conversation"],
|
||||
"dependencies": ["conversation", "ai_task"],
|
||||
"after_dependencies": ["assist_pipeline", "intent"],
|
||||
"documentation": "https://github.com/acon96/home-llm",
|
||||
"integration_type": "service",
|
||||
"iot_class": "local_polling",
|
||||
"requirements": [
|
||||
"huggingface-hub>=0.23.0",
|
||||
"webcolors>=24.8.0"
|
||||
"webcolors>=24.8.0",
|
||||
"ollama>=0.5.1"
|
||||
]
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@
|
||||
"model_parameters": {
|
||||
"data": {
|
||||
"max_new_tokens": "Maximum tokens to return in response",
|
||||
"llm_hass_api": "Selected LLM API",
|
||||
"llm_hass_api": "Selected LLM API(s)",
|
||||
"prompt": "System Prompt",
|
||||
"temperature": "Temperature",
|
||||
"top_k": "Top K",
|
||||
@@ -109,7 +109,7 @@
|
||||
"max_tool_call_iterations": "Maximum Tool Call Attempts"
|
||||
},
|
||||
"data_description": {
|
||||
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-3)'",
|
||||
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM (v1-3) model then select 'Home Assistant Services'",
|
||||
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
|
||||
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
|
||||
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
|
||||
@@ -124,7 +124,7 @@
|
||||
"reconfigure": {
|
||||
"data": {
|
||||
"max_new_tokens": "Maximum tokens to return in response",
|
||||
"llm_hass_api": "Selected LLM API",
|
||||
"llm_hass_api": "Selected LLM API(s)",
|
||||
"prompt": "System Prompt",
|
||||
"temperature": "Temperature",
|
||||
"top_k": "Top K",
|
||||
@@ -163,7 +163,7 @@
|
||||
"max_tool_call_iterations": "Maximum Tool Call Attempts"
|
||||
},
|
||||
"data_description": {
|
||||
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-3)'",
|
||||
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM (v1-3) model then select 'Home Assistant Services'",
|
||||
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
|
||||
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
|
||||
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
|
||||
@@ -177,7 +177,7 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"ai_task_data": {
|
||||
"ai_task": {
|
||||
"initiate_flow": {
|
||||
"user": "Add AI Task Handler",
|
||||
"reconfigure": "Reconfigure AI Task Handler"
|
||||
@@ -246,7 +246,9 @@
|
||||
"tool_call_prefix": "Tool Call Prefix",
|
||||
"tool_call_suffix": "Tool Call Suffix",
|
||||
"enable_legacy_tool_calling": "Enable Legacy Tool Calling",
|
||||
"max_tool_call_iterations": "Maximum Tool Call Attempts"
|
||||
"max_tool_call_iterations": "Maximum Tool Call Attempts",
|
||||
"ai_task_extraction_method": "Structured Data Extraction Method",
|
||||
"ai_task_retries": "Retry attempts for structured data extraction"
|
||||
},
|
||||
"data_description": {
|
||||
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
|
||||
@@ -255,7 +257,8 @@
|
||||
"gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.",
|
||||
"prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below.",
|
||||
"enable_legacy_tool_calling": "Prefer to process tool calls locally rather than relying on the backend to handle the tool calling format. Can be more reliable, however it requires properly setting the tool call prefix and suffix.",
|
||||
"max_tool_call_iterations": "Set to 0 to generate the response and tool call in one attempt, without looping (use this for Home models v1-v3)."
|
||||
"max_tool_call_iterations": "Set to 0 to generate the response and tool call in one attempt, without looping (use this for Home models v1-v3).",
|
||||
"ai_task_extraction_method": "Select the method used to extract structured data from the model's response. 'Structured Output' tells the backend to force the model to produce output following the provided JSON Schema; 'Tool Calling' provides a tool to the model that should be called with the appropriate arguments that match the desired output structure."
|
||||
},
|
||||
"description": "Please configure the model according to how it should be prompted. There are many different options and selecting the correct ones for your model is essential to getting optimal performance. See [here](https://github.com/acon96/home-llm/blob/develop/docs/Backend%20Configuration.md) for more information about the options on this page.\n\n**Some defaults may have been chosen for you based on the name of the selected model name or filename.** If you renamed a file or are using a fine-tuning of a supported model, then the defaults may not have been detected.",
|
||||
"title": "Configure the selected model"
|
||||
@@ -263,7 +266,7 @@
|
||||
"reconfigure": {
|
||||
"data": {
|
||||
"max_new_tokens": "Maximum tokens to return in response",
|
||||
"llm_hass_api": "Selected LLM API",
|
||||
"llm_hass_api": "Selected LLM API(s)",
|
||||
"prompt": "System Prompt",
|
||||
"temperature": "Temperature",
|
||||
"top_k": "Top K",
|
||||
|
||||
@@ -24,7 +24,7 @@ from homeassistant.requirements import pip_kwargs
|
||||
from homeassistant.util import color
|
||||
from homeassistant.util.package import install_package, is_installed
|
||||
|
||||
from voluptuous_openapi import convert
|
||||
from voluptuous_openapi import convert as convert_to_openapi
|
||||
|
||||
from .const import (
|
||||
DOMAIN,
|
||||
@@ -32,7 +32,7 @@ from .const import (
|
||||
ALLOWED_SERVICE_CALL_ARGUMENTS,
|
||||
SERVICE_TOOL_ALLOWED_SERVICES,
|
||||
SERVICE_TOOL_ALLOWED_DOMAINS,
|
||||
HOME_LLM_API_ID,
|
||||
SERVICE_TOOL_NAME,
|
||||
)
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -275,26 +275,30 @@ def install_llama_cpp_python(config_dir: str, force_reinstall: bool = False, spe
|
||||
def format_url(*, hostname: str, port: str, ssl: bool, path: str):
|
||||
return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
|
||||
|
||||
def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
|
||||
if llm_api.api.id == HOME_LLM_API_ID:
|
||||
result: List[ChatCompletionTool] = [ {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool["name"],
|
||||
"description": f"Call the Home Assistant service '{tool['name']}'",
|
||||
"parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
|
||||
}
|
||||
} for tool in get_home_llm_tools(llm_api, domains) ]
|
||||
|
||||
else:
|
||||
result: List[ChatCompletionTool] = [ {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
|
||||
}
|
||||
} for tool in llm_api.tools ]
|
||||
def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
|
||||
result: List[ChatCompletionTool] = []
|
||||
|
||||
for tool in llm_api.tools:
|
||||
# when combining with home assistant llm APIs, it adds a prefix to differentiate tools; compare against the suffix here
|
||||
if tool.name.endswith(SERVICE_TOOL_NAME):
|
||||
result.extend([{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool["name"],
|
||||
"description": f"Call the Home Assistant service '{tool['name']}'",
|
||||
"parameters": convert_to_openapi(tool["arguments"], custom_serializer=llm_api.custom_serializer)
|
||||
}
|
||||
} for tool in get_home_llm_tools(llm_api, domains) ])
|
||||
else:
|
||||
result.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"parameters": convert_to_openapi(tool.parameters, custom_serializer=llm_api.custom_serializer)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
return result
|
||||
|
||||
@@ -396,41 +400,44 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic
|
||||
|
||||
return tools
|
||||
|
||||
def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
|
||||
def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
|
||||
if isinstance(raw_block, dict):
|
||||
parsed_tool_call = raw_block
|
||||
else:
|
||||
parsed_tool_call: dict = json.loads(raw_block)
|
||||
|
||||
if llm_api.api.id == HOME_LLM_API_ID:
|
||||
schema_to_validate = vol.Schema({
|
||||
vol.Required('service'): str,
|
||||
vol.Required('target_device'): str,
|
||||
vol.Optional('rgb_color'): str,
|
||||
vol.Optional('brightness'): vol.Coerce(float),
|
||||
vol.Optional('temperature'): vol.Coerce(float),
|
||||
vol.Optional('humidity'): vol.Coerce(float),
|
||||
vol.Optional('fan_mode'): str,
|
||||
vol.Optional('hvac_mode'): str,
|
||||
vol.Optional('preset_mode'): str,
|
||||
vol.Optional('duration'): str,
|
||||
vol.Optional('item'): str,
|
||||
})
|
||||
else:
|
||||
schema_to_validate = vol.Schema({
|
||||
# try to validate either format
|
||||
is_services_tool_call = False
|
||||
try:
|
||||
base_schema_to_validate = vol.Schema({
|
||||
vol.Required("name"): str,
|
||||
vol.Required("arguments"): vol.Union(str, dict),
|
||||
})
|
||||
|
||||
try:
|
||||
schema_to_validate(parsed_tool_call)
|
||||
base_schema_to_validate(parsed_tool_call)
|
||||
except vol.Error as ex:
|
||||
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
|
||||
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
|
||||
try:
|
||||
home_llm_schema_to_validate = vol.Schema({
|
||||
vol.Required('service'): str,
|
||||
vol.Required('target_device'): str,
|
||||
vol.Optional('rgb_color'): str,
|
||||
vol.Optional('brightness'): vol.Coerce(float),
|
||||
vol.Optional('temperature'): vol.Coerce(float),
|
||||
vol.Optional('humidity'): vol.Coerce(float),
|
||||
vol.Optional('fan_mode'): str,
|
||||
vol.Optional('hvac_mode'): str,
|
||||
vol.Optional('preset_mode'): str,
|
||||
vol.Optional('duration'): str,
|
||||
vol.Optional('item'): str,
|
||||
})
|
||||
home_llm_schema_to_validate(parsed_tool_call)
|
||||
is_services_tool_call = True
|
||||
except vol.Error as ex:
|
||||
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
|
||||
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
|
||||
|
||||
# try to fix certain arguments
|
||||
args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
|
||||
tool_name = parsed_tool_call.get("name", parsed_tool_call.get("service", ""))
|
||||
args_dict = parsed_tool_call if is_services_tool_call else parsed_tool_call["arguments"]
|
||||
tool_name = SERVICE_TOOL_NAME if is_services_tool_call else parsed_tool_call["name"]
|
||||
|
||||
if isinstance(args_dict, str):
|
||||
if not args_dict.strip():
|
||||
|
||||
@@ -6,4 +6,5 @@ home-assistant-intents
|
||||
# testing requirements
|
||||
pytest
|
||||
pytest-asyncio
|
||||
pytest-homeassistant-custom-component==0.13.260
|
||||
# NOTE this must match the version of Home Assistant used for testing
|
||||
pytest-homeassistant-custom-component==0.13.272
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
huggingface-hub>=0.23.0
|
||||
webcolors>=24.8.0
|
||||
ollama>=0.5.1
|
||||
|
||||
@@ -28,8 +28,8 @@ Start by installing system dependencies:
|
||||
Then create a Python virtual environment and install all necessary library:
|
||||
```
|
||||
python3 -m venv .generate_data
|
||||
source ./.generate_data/bin/activate
|
||||
pip3 install pandas==2.2.2 datasets==2.20.0 webcolors==1.13 babel==2.15.0
|
||||
source .generate_data/bin/activate
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
## Generating the dataset from piles
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
datasets>=3.2.0
|
||||
webcolors>=1.13
|
||||
webcolors>=24.8.0
|
||||
pandas>=2.2.3
|
||||
deep-translator>=1.11.4
|
||||
langcodes>=3.5.0
|
||||
|
||||
52
docs/AI Tasks.md
Normal file
52
docs/AI Tasks.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# Using AI Tasks
|
||||
The AI Tasks feature allows you to define structured tasks that your local LLM can perform. These tasks can be integrated into Home Assistant automations and scripts, enabling you to generate dynamic content based on specific prompts and instructions.
|
||||
|
||||
## Setting up an AI Task Handler
|
||||
Setting up a task handler is similar to setting up a conversation agent. You can choose to run the model directly within Home Assistant using `llama-cpp-python`, or you can use an external backend like Ollama. See the [Setup Guide](./docs/Setup.md) for detailed instructions on configuring your AI Task handler.
|
||||
|
||||
The specific configuration options for AI Tasks are:
|
||||
| Option Name | Description |
|
||||
|-----------------------------------|-------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| Structured Data Extraction Method | Choose how the AI Task should extract structured data from the model's output. Options include `structured_output` and `tool`. |
|
||||
| Data Extraction Retry Count | The number of times to retry data extraction if the initial attempt fails. Useful when models can produce incorrect tool responses. |
|
||||
|
||||
If no structured data extraction method is specified, then the task entity will always return raw text.
|
||||
|
||||
## Using an AI Task in a Script or Automation
|
||||
To use an AI Task in a Home Assistant script or automation, you can utilize the `ai_task.generate_data` action. This action allows you to specify the task name, instructions, and the structure of the expected output. Below is an example of a script that generates a joke about a smart device in your home.
|
||||
|
||||
**Device Joke Script:**
|
||||
```yaml
|
||||
sequence:
|
||||
- action: ai_task.generate_data
|
||||
data:
|
||||
task_name: Device Joke Generation
|
||||
instructions: |
|
||||
Write a funny joke about one of the smart devices in my home.
|
||||
Here are all of the smart devices I have:
|
||||
{% for device in states | rejectattr('domain', 'in', ['update', 'event']) -%}
|
||||
- {{ device.name }} ({{device.domain}})
|
||||
{% endfor %}
|
||||
# You MUST set this to your own LLM entity ID if you do not set a default one in HA Settings
|
||||
# entity_id: ai_task.unsloth_qwen3_0_6b_gguf_unsloth_qwen3_0_6b_gguf
|
||||
structure:
|
||||
joke_setup:
|
||||
description: The beginning of a joke about a smart device in the home
|
||||
required: true
|
||||
selector:
|
||||
text: null
|
||||
joke_punchline:
|
||||
description: The punchline of the same joke about the smart device
|
||||
required: true
|
||||
selector:
|
||||
text: null
|
||||
response_variable: joke_output
|
||||
- action: notify.persistent_notification
|
||||
data:
|
||||
message: |-
|
||||
{{ joke_output.data.joke_setup }}
|
||||
...
|
||||
{{ joke_output.data.joke_punchline }}
|
||||
alias: Device Joke
|
||||
description: "Generates a funny joke about one of the smart devices in the home."
|
||||
```
|
||||
@@ -4,7 +4,7 @@ datasets>=3.2.0
|
||||
peft>=0.14.0
|
||||
bitsandbytes>=0.45.2
|
||||
trl>=0.14.0
|
||||
webcolors>=1.13
|
||||
webcolors>=24.8.0
|
||||
pandas>=2.2.3
|
||||
# flash-attn
|
||||
sentencepiece>=0.2.0
|
||||
|
||||
@@ -1,763 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import pytest
|
||||
import jinja2
|
||||
from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY
|
||||
|
||||
from custom_components.llama_conversation.backends.llamacpp import LlamaCppAgent
|
||||
from custom_components.llama_conversation.backends.ollama import OllamaAPIAgent
|
||||
from custom_components.llama_conversation.backends.tailored_openai import TextGenerationWebuiAgent
|
||||
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_PROMPT,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_MIN_P,
|
||||
CONF_TYPICAL_P,
|
||||
CONF_REQUEST_TIMEOUT,
|
||||
CONF_BACKEND_TYPE,
|
||||
CONF_DOWNLOADED_MODEL_FILE,
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_PROMPT_TEMPLATE,
|
||||
CONF_ENABLE_FLASH_ATTENTION,
|
||||
CONF_USE_GBNF_GRAMMAR,
|
||||
CONF_GBNF_GRAMMAR_FILE,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
||||
CONF_TEXT_GEN_WEBUI_PRESET,
|
||||
CONF_OPENAI_API_KEY,
|
||||
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
|
||||
CONF_REFRESH_SYSTEM_PROMPT,
|
||||
CONF_REMEMBER_CONVERSATION,
|
||||
CONF_REMEMBER_NUM_INTERACTIONS,
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
CONF_PROMPT_CACHING_INTERVAL,
|
||||
CONF_SERVICE_CALL_REGEX,
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT,
|
||||
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
CONF_OLLAMA_KEEP_ALIVE_MIN,
|
||||
CONF_OLLAMA_JSON_MODE,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_BATCH_SIZE,
|
||||
CONF_THREAD_COUNT,
|
||||
CONF_BATCH_THREAD_COUNT,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_PROMPT_BASE,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_K,
|
||||
DEFAULT_TOP_P,
|
||||
DEFAULT_MIN_P,
|
||||
DEFAULT_TYPICAL_P,
|
||||
DEFAULT_BACKEND_TYPE,
|
||||
DEFAULT_REQUEST_TIMEOUT,
|
||||
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
DEFAULT_PROMPT_TEMPLATE,
|
||||
DEFAULT_ENABLE_FLASH_ATTENTION,
|
||||
DEFAULT_USE_GBNF_GRAMMAR,
|
||||
DEFAULT_GBNF_GRAMMAR_FILE,
|
||||
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
||||
DEFAULT_REFRESH_SYSTEM_PROMPT,
|
||||
DEFAULT_REMEMBER_CONVERSATION,
|
||||
DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
||||
DEFAULT_PROMPT_CACHING_ENABLED,
|
||||
DEFAULT_PROMPT_CACHING_INTERVAL,
|
||||
DEFAULT_SERVICE_CALL_REGEX,
|
||||
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
|
||||
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
|
||||
DEFAULT_OLLAMA_JSON_MODE,
|
||||
DEFAULT_CONTEXT_LENGTH,
|
||||
DEFAULT_BATCH_SIZE,
|
||||
DEFAULT_THREAD_COUNT,
|
||||
DEFAULT_BATCH_THREAD_COUNT,
|
||||
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
|
||||
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
|
||||
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
|
||||
DOMAIN,
|
||||
PROMPT_TEMPLATE_DESCRIPTIONS,
|
||||
DEFAULT_OPTIONS,
|
||||
)
|
||||
|
||||
from homeassistant.components.conversation import ConversationInput
|
||||
from homeassistant.const import (
|
||||
CONF_HOST,
|
||||
CONF_PORT,
|
||||
CONF_SSL,
|
||||
CONF_LLM_HASS_API
|
||||
)
|
||||
from homeassistant.helpers.llm import LLM_API_ASSIST, APIInstance
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class WarnDict(dict):
|
||||
def get(self, _key, _default=None):
|
||||
if _key in self:
|
||||
return self[_key]
|
||||
|
||||
_LOGGER.warning(f"attempting to get unset dictionary key {_key}")
|
||||
|
||||
return _default
|
||||
|
||||
class MockConfigEntry:
|
||||
def __init__(self, entry_id='test_entry_id', data={}, options={}):
|
||||
self.entry_id = entry_id
|
||||
self.data = WarnDict(data)
|
||||
# Use a mutable dict for options in tests
|
||||
self.options = WarnDict(dict(options))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_entry():
|
||||
yield MockConfigEntry(
|
||||
data={
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
CONF_BACKEND_TYPE: DEFAULT_BACKEND_TYPE,
|
||||
CONF_DOWNLOADED_MODEL_FILE: "/config/models/some-model.q4_k_m.gguf",
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5000",
|
||||
CONF_SSL: False,
|
||||
CONF_OPENAI_API_KEY: "OpenAI-API-Key",
|
||||
CONF_TEXT_GEN_WEBUI_ADMIN_KEY: "Text-Gen-Webui-Admin-Key"
|
||||
},
|
||||
options={
|
||||
**DEFAULT_OPTIONS,
|
||||
CONF_LLM_HASS_API: LLM_API_ASSIST,
|
||||
CONF_PROMPT: DEFAULT_PROMPT_BASE,
|
||||
CONF_SERVICE_CALL_REGEX: r"({[\S \t]*})"
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def local_llama_agent_fixture(config_entry, hass, enable_custom_integrations):
|
||||
with patch.object(LlamaCppAgent, '_load_icl_examples') as load_icl_examples_mock, \
|
||||
patch.object(LlamaCppAgent, '_load_grammar') as load_grammar_mock, \
|
||||
patch.object(LlamaCppAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
|
||||
patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
|
||||
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
|
||||
patch('homeassistant.helpers.template.Template') as template_mock, \
|
||||
patch('custom_components.llama_conversation.backends.llamacpp.importlib.import_module') as import_module_mock, \
|
||||
patch('custom_components.llama_conversation.utils.importlib.import_module') as import_module_mock_2, \
|
||||
patch('custom_components.llama_conversation.utils.install_llama_cpp_python') as install_llama_cpp_python_mock:
|
||||
|
||||
entry_mock.return_value = config_entry
|
||||
llama_instance_mock = MagicMock()
|
||||
llama_class_mock = MagicMock()
|
||||
llama_class_mock.return_value = llama_instance_mock
|
||||
import_module_mock.return_value = MagicMock(Llama=llama_class_mock)
|
||||
import_module_mock_2.return_value = MagicMock(Llama=llama_class_mock)
|
||||
install_llama_cpp_python_mock.return_value = True
|
||||
get_exposed_entities_mock.return_value = (
|
||||
{
|
||||
"light.kitchen_light": { "state": "on" },
|
||||
"light.office_lamp": { "state": "on" },
|
||||
"switch.downstairs_hallway": { "state": "off" },
|
||||
"fan.bedroom": { "state": "on" },
|
||||
},
|
||||
["light", "switch", "fan"]
|
||||
)
|
||||
# template_mock.side_affect = lambda template, _: jinja2.Template(template)
|
||||
generate_mock = llama_instance_mock.generate
|
||||
generate_mock.return_value = list(range(20))
|
||||
|
||||
detokenize_mock = llama_instance_mock.detokenize
|
||||
detokenize_mock.return_value = ("I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
})).encode()
|
||||
|
||||
call_tool_mock.return_value = {"result": "success"}
|
||||
|
||||
agent_obj = LlamaCppAgent(
|
||||
hass,
|
||||
config_entry
|
||||
)
|
||||
|
||||
all_mocks = {
|
||||
"llama_class": llama_class_mock,
|
||||
"tokenize": llama_instance_mock.tokenize,
|
||||
"generate": generate_mock,
|
||||
}
|
||||
|
||||
yield agent_obj, all_mocks
|
||||
|
||||
# TODO: test base llama agent (ICL loading other languages)
|
||||
|
||||
async def test_local_llama_agent(local_llama_agent_fixture):
|
||||
|
||||
local_llama_agent: LlamaCppAgent
|
||||
all_mocks: dict[str, MagicMock]
|
||||
local_llama_agent, all_mocks = local_llama_agent_fixture
|
||||
|
||||
# invoke the conversation agent
|
||||
conversation_id = "test-conversation"
|
||||
result = await local_llama_agent.async_process(ConversationInput(
|
||||
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
# assert on results + check side effects
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["llama_class"].assert_called_once_with(
|
||||
model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
|
||||
n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
|
||||
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
|
||||
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
|
||||
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
|
||||
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
|
||||
)
|
||||
|
||||
all_mocks["tokenize"].assert_called_once()
|
||||
all_mocks["generate"].assert_called_once_with(
|
||||
ANY,
|
||||
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
|
||||
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
|
||||
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
|
||||
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
|
||||
min_p=local_llama_agent.entry.options[CONF_MIN_P],
|
||||
grammar=ANY,
|
||||
)
|
||||
|
||||
# reset mock stats
|
||||
for mock in all_mocks.values():
|
||||
mock.reset_mock()
|
||||
|
||||
# change options then apply them
|
||||
local_llama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
|
||||
local_llama_agent.entry.options[CONF_BATCH_SIZE] = 1024
|
||||
local_llama_agent.entry.options[CONF_THREAD_COUNT] = 24
|
||||
local_llama_agent.entry.options[CONF_BATCH_THREAD_COUNT] = 24
|
||||
local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
|
||||
local_llama_agent.entry.options[CONF_ENABLE_FLASH_ATTENTION] = True
|
||||
local_llama_agent.entry.options[CONF_TOP_K] = 20
|
||||
local_llama_agent.entry.options[CONF_TOP_P] = 0.9
|
||||
local_llama_agent.entry.options[CONF_MIN_P] = 0.2
|
||||
local_llama_agent.entry.options[CONF_TYPICAL_P] = 0.95
|
||||
|
||||
local_llama_agent._update_options()
|
||||
|
||||
all_mocks["llama_class"].assert_called_once_with(
|
||||
model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
|
||||
n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
|
||||
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
|
||||
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
|
||||
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
|
||||
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
|
||||
)
|
||||
|
||||
# do another turn of the same conversation
|
||||
result = await local_llama_agent.async_process(ConversationInput(
|
||||
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["tokenize"].assert_called_once()
|
||||
all_mocks["generate"].assert_called_once_with(
|
||||
ANY,
|
||||
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
|
||||
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
|
||||
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
|
||||
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
|
||||
min_p=local_llama_agent.entry.options[CONF_MIN_P],
|
||||
grammar=ANY,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
|
||||
with patch.object(OllamaAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
|
||||
patch.object(OllamaAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
|
||||
patch.object(OllamaAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
|
||||
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
|
||||
patch('homeassistant.helpers.template.Template') as template_mock, \
|
||||
patch('custom_components.llama_conversation.backends.ollama.async_get_clientsession') as get_clientsession:
|
||||
|
||||
entry_mock.return_value = config_entry
|
||||
get_exposed_entities_mock.return_value = (
|
||||
{
|
||||
"light.kitchen_light": { "state": "on" },
|
||||
"light.office_lamp": { "state": "on" },
|
||||
"switch.downstairs_hallway": { "state": "off" },
|
||||
"fan.bedroom": { "state": "on" },
|
||||
},
|
||||
["light", "switch", "fan"]
|
||||
)
|
||||
|
||||
response_mock = MagicMock()
|
||||
response_mock.json.return_value = { "models": [ {"name": config_entry.data[CONF_CHAT_MODEL] }] }
|
||||
get_clientsession.get.return_value = response_mock
|
||||
|
||||
call_tool_mock.return_value = {"result": "success"}
|
||||
|
||||
agent_obj = OllamaAPIAgent(
|
||||
hass,
|
||||
config_entry
|
||||
)
|
||||
|
||||
all_mocks = {
|
||||
"requests_get": get_clientsession.get,
|
||||
"requests_post": get_clientsession.post
|
||||
}
|
||||
|
||||
yield agent_obj, all_mocks
|
||||
|
||||
async def test_ollama_agent(ollama_agent_fixture):
|
||||
|
||||
ollama_agent: OllamaAPIAgent
|
||||
all_mocks: dict[str, MagicMock]
|
||||
ollama_agent, all_mocks = ollama_agent_fixture
|
||||
|
||||
all_mocks["requests_get"].assert_called_once_with(
|
||||
"http://localhost:5000/api/tags",
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" }
|
||||
)
|
||||
|
||||
response_mock = MagicMock()
|
||||
response_mock.json.return_value = {
|
||||
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"created_at": "2023-11-09T21:07:55.186497Z",
|
||||
"response": "I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
}),
|
||||
"done": True,
|
||||
"context": [1, 2, 3],
|
||||
"total_duration": 4648158584,
|
||||
"load_duration": 4071084,
|
||||
"prompt_eval_count": 36,
|
||||
"prompt_eval_duration": 439038000,
|
||||
"eval_count": 180,
|
||||
"eval_duration": 4196918000
|
||||
}
|
||||
all_mocks["requests_post"].return_value = response_mock
|
||||
|
||||
# invoke the conversation agent
|
||||
conversation_id = "test-conversation"
|
||||
result = await ollama_agent.async_process(ConversationInput(
|
||||
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
# assert on results + check side effects
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/api/generate",
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
json={
|
||||
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"stream": False,
|
||||
"keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
|
||||
"options": {
|
||||
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
|
||||
"top_p": ollama_agent.entry.options[CONF_TOP_P],
|
||||
"top_k": ollama_agent.entry.options[CONF_TOP_K],
|
||||
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
|
||||
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
|
||||
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
|
||||
},
|
||||
"prompt": ANY,
|
||||
"raw": True
|
||||
},
|
||||
timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
# reset mock stats
|
||||
for mock in all_mocks.values():
|
||||
mock.reset_mock()
|
||||
|
||||
# change options
|
||||
ollama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
|
||||
ollama_agent.entry.options[CONF_MAX_TOKENS] = 10
|
||||
ollama_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
|
||||
ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN] = 99
|
||||
ollama_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
|
||||
ollama_agent.entry.options[CONF_OLLAMA_JSON_MODE] = True
|
||||
ollama_agent.entry.options[CONF_TEMPERATURE] = 2.0
|
||||
ollama_agent.entry.options[CONF_TOP_K] = 20
|
||||
ollama_agent.entry.options[CONF_TOP_P] = 0.9
|
||||
ollama_agent.entry.options[CONF_TYPICAL_P] = 0.5
|
||||
|
||||
# do another turn of the same conversation
|
||||
result = await ollama_agent.async_process(ConversationInput(
|
||||
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/api/chat",
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
json={
|
||||
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"stream": False,
|
||||
"format": "json",
|
||||
"keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
|
||||
"options": {
|
||||
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
|
||||
"top_p": ollama_agent.entry.options[CONF_TOP_P],
|
||||
"top_k": ollama_agent.entry.options[CONF_TOP_K],
|
||||
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
|
||||
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
|
||||
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
|
||||
},
|
||||
"messages": ANY
|
||||
},
|
||||
timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integrations):
|
||||
with patch.object(TextGenerationWebuiAgent, '_load_icl_examples') as load_icl_examples_mock, \
|
||||
patch.object(TextGenerationWebuiAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
|
||||
patch.object(TextGenerationWebuiAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
|
||||
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
|
||||
patch('homeassistant.helpers.template.Template') as template_mock, \
|
||||
patch('custom_components.llama_conversation.backends.tailored_openai.async_get_clientsession') as get_clientsession_mock:
|
||||
|
||||
entry_mock.return_value = config_entry
|
||||
get_exposed_entities_mock.return_value = (
|
||||
{
|
||||
"light.kitchen_light": { "state": "on" },
|
||||
"light.office_lamp": { "state": "on" },
|
||||
"switch.downstairs_hallway": { "state": "off" },
|
||||
"fan.bedroom": { "state": "on" },
|
||||
},
|
||||
["light", "switch", "fan"]
|
||||
)
|
||||
|
||||
response_mock = MagicMock()
|
||||
response_mock.json.return_value = { "model_name": config_entry.data[CONF_CHAT_MODEL] }
|
||||
get_clientsession_mock.get.return_value = response_mock
|
||||
|
||||
call_tool_mock.return_value = {"result": "success"}
|
||||
|
||||
agent_obj = TextGenerationWebuiAgent(
|
||||
hass,
|
||||
config_entry
|
||||
)
|
||||
|
||||
all_mocks = {
|
||||
"requests_get": get_clientsession_mock.get,
|
||||
"requests_post": get_clientsession_mock.post
|
||||
}
|
||||
|
||||
yield agent_obj, all_mocks
|
||||
|
||||
async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
|
||||
|
||||
text_generation_webui_agent: TextGenerationWebuiAgent
|
||||
all_mocks: dict[str, MagicMock]
|
||||
text_generation_webui_agent, all_mocks = text_generation_webui_agent_fixture
|
||||
|
||||
all_mocks["requests_get"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/internal/model/info",
|
||||
headers={ "Authorization": "Bearer Text-Gen-Webui-Admin-Key" }
|
||||
)
|
||||
|
||||
response_mock = MagicMock()
|
||||
response_mock.json.return_value = {
|
||||
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
|
||||
"object": "text_completion",
|
||||
"created": 1589478378,
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [{
|
||||
"text": "I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
}),
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "length"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 12
|
||||
}
|
||||
}
|
||||
all_mocks["requests_post"].return_value = response_mock
|
||||
|
||||
# invoke the conversation agent
|
||||
conversation_id = "test-conversation"
|
||||
result = await text_generation_webui_agent.async_process(ConversationInput(
|
||||
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
# assert on results + check side effects
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/completions",
|
||||
json={
|
||||
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
|
||||
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
|
||||
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
|
||||
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
|
||||
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
|
||||
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
|
||||
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
|
||||
"prompt": ANY
|
||||
},
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
# reset mock stats
|
||||
for mock in all_mocks.values():
|
||||
mock.reset_mock()
|
||||
|
||||
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = "Some Preset"
|
||||
|
||||
# do another turn of the same conversation and use a preset
|
||||
result = await text_generation_webui_agent.async_process(ConversationInput(
|
||||
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/completions",
|
||||
json={
|
||||
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
|
||||
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
|
||||
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
|
||||
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
|
||||
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
|
||||
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
|
||||
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
|
||||
"preset": "Some Preset",
|
||||
"prompt": ANY
|
||||
},
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
# change options
|
||||
text_generation_webui_agent.entry.options[CONF_MAX_TOKENS] = 10
|
||||
text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
|
||||
text_generation_webui_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
|
||||
text_generation_webui_agent.entry.options[CONF_TEMPERATURE] = 2.0
|
||||
text_generation_webui_agent.entry.options[CONF_TOP_P] = 0.9
|
||||
text_generation_webui_agent.entry.options[CONF_MIN_P] = 0.2
|
||||
text_generation_webui_agent.entry.options[CONF_TYPICAL_P] = 0.95
|
||||
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = ""
|
||||
|
||||
response_mock.json.return_value = {
|
||||
"id": "chatcmpl-123",
|
||||
# text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
|
||||
"object": "chat.completions",
|
||||
"created": 1677652288,
|
||||
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
}),
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 9,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 21
|
||||
}
|
||||
}
|
||||
|
||||
# reset mock stats
|
||||
for mock in all_mocks.values():
|
||||
mock.reset_mock()
|
||||
|
||||
# do another turn of the same conversation but the chat endpoint
|
||||
result = await text_generation_webui_agent.async_process(ConversationInput(
|
||||
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/chat/completions",
|
||||
json={
|
||||
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
|
||||
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
|
||||
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
|
||||
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
|
||||
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
|
||||
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
|
||||
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
|
||||
"mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
|
||||
"messages": ANY
|
||||
},
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations):
|
||||
with patch.object(GenericOpenAIAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
|
||||
patch.object(GenericOpenAIAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
|
||||
patch.object(GenericOpenAIAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
|
||||
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
|
||||
patch('homeassistant.helpers.template.Template') as template_mock, \
|
||||
patch('custom_components.llama_conversation.backends.generic_openai.async_get_clientsession') as get_clientsession_mock:
|
||||
|
||||
entry_mock.return_value = config_entry
|
||||
get_exposed_entities_mock.return_value = (
|
||||
{
|
||||
"light.kitchen_light": { "state": "on" },
|
||||
"light.office_lamp": { "state": "on" },
|
||||
"switch.downstairs_hallway": { "state": "off" },
|
||||
"fan.bedroom": { "state": "on" },
|
||||
},
|
||||
["light", "switch", "fan"]
|
||||
)
|
||||
|
||||
call_tool_mock.return_value = {"result": "success"}
|
||||
|
||||
agent_obj = GenericOpenAIAPIAgent(
|
||||
hass,
|
||||
config_entry
|
||||
)
|
||||
|
||||
all_mocks = {
|
||||
"requests_get": get_clientsession_mock.get,
|
||||
"requests_post": get_clientsession_mock.post
|
||||
}
|
||||
|
||||
yield agent_obj, all_mocks
|
||||
|
||||
async def test_generic_openai_agent(generic_openai_agent_fixture):
|
||||
|
||||
generic_openai_agent: TextGenerationWebuiAgent
|
||||
all_mocks: dict[str, MagicMock]
|
||||
generic_openai_agent, all_mocks = generic_openai_agent_fixture
|
||||
|
||||
response_mock = MagicMock()
|
||||
response_mock.json.return_value = {
|
||||
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
|
||||
"object": "text_completion",
|
||||
"created": 1589478378,
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [{
|
||||
"text": "I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
}),
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "length"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 12
|
||||
}
|
||||
}
|
||||
all_mocks["requests_post"].return_value = response_mock
|
||||
|
||||
# invoke the conversation agent
|
||||
conversation_id = "test-conversation"
|
||||
result = await generic_openai_agent.async_process(ConversationInput(
|
||||
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
# assert on results + check side effects
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/completions",
|
||||
json={
|
||||
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"top_p": generic_openai_agent.entry.options[CONF_TOP_P],
|
||||
"temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
|
||||
"max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
|
||||
"prompt": ANY
|
||||
},
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
# reset mock stats
|
||||
for mock in all_mocks.values():
|
||||
mock.reset_mock()
|
||||
|
||||
# change options
|
||||
generic_openai_agent.entry.options[CONF_MAX_TOKENS] = 10
|
||||
generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
|
||||
generic_openai_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
|
||||
generic_openai_agent.entry.options[CONF_TEMPERATURE] = 2.0
|
||||
generic_openai_agent.entry.options[CONF_TOP_P] = 0.9
|
||||
|
||||
response_mock.json.return_value = {
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
}),
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 9,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 21
|
||||
}
|
||||
}
|
||||
|
||||
# do another turn of the same conversation but the chat endpoint
|
||||
result = await generic_openai_agent.async_process(ConversationInput(
|
||||
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/chat/completions",
|
||||
json={
|
||||
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"top_p": generic_openai_agent.entry.options[CONF_TOP_P],
|
||||
"temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
|
||||
"max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
|
||||
"messages": ANY
|
||||
},
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
144
tests/llama_conversation/test_ai_task.py
Normal file
144
tests/llama_conversation/test_ai_task.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Tests for AI Task extraction behavior."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from custom_components.llama_conversation.ai_task import (
|
||||
LocalLLMTaskEntity,
|
||||
ResultExtractionMethod,
|
||||
)
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_AI_TASK_EXTRACTION_METHOD,
|
||||
)
|
||||
from custom_components.llama_conversation.entity import TextGenerationResult
|
||||
|
||||
|
||||
class DummyGenTask:
|
||||
def __init__(self, *, name="task", instructions="do", attachments=None, structure=None):
|
||||
self.name = name
|
||||
self.instructions = instructions
|
||||
self.attachments = attachments or []
|
||||
self.structure = structure
|
||||
|
||||
|
||||
class DummyChatLog:
|
||||
def __init__(self, content=None):
|
||||
self.content = content or []
|
||||
self.conversation_id = "conv-id"
|
||||
self.llm_api = None
|
||||
|
||||
class DummyClient:
|
||||
def __init__(self, result: TextGenerationResult):
|
||||
self._result = result
|
||||
|
||||
def _generate_system_prompt(self, prompt_template, llm_api, entity_options):
|
||||
return prompt_template
|
||||
|
||||
def _supports_vision(self, _options): # pragma: no cover - not needed for tests
|
||||
return False
|
||||
|
||||
async def _generate(self, _messages, _llm_api, _entity_id, _options):
|
||||
return self._result
|
||||
|
||||
|
||||
class DummyTaskEntity(LocalLLMTaskEntity):
|
||||
def __init__(self, hass, client, options):
|
||||
# Bypass parent init to avoid ConfigEntry/Subentry plumbing.
|
||||
self.hass = hass
|
||||
self.client = client
|
||||
self._runtime_options = options
|
||||
self.entity_id = "ai_task.test"
|
||||
self.entry_id = "entry"
|
||||
self.subentry_id = "subentry"
|
||||
self._attr_supported_features = 0
|
||||
|
||||
@property
|
||||
def runtime_options(self):
|
||||
return self._runtime_options
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_extraction_returns_raw_text(hass):
|
||||
entity = DummyTaskEntity(
|
||||
hass,
|
||||
DummyClient(TextGenerationResult(response="raw text")),
|
||||
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.NONE},
|
||||
)
|
||||
chat_log = DummyChatLog()
|
||||
task = DummyGenTask()
|
||||
|
||||
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
|
||||
|
||||
assert result.data == "raw text"
|
||||
assert chat_log.llm_api is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_output_success(hass):
|
||||
entity = DummyTaskEntity(
|
||||
hass,
|
||||
DummyClient(TextGenerationResult(response='{"foo": 1}')),
|
||||
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
|
||||
)
|
||||
chat_log = DummyChatLog()
|
||||
task = DummyGenTask(structure=vol.Schema({"foo": int}))
|
||||
|
||||
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
|
||||
|
||||
assert result.data == {"foo": 1}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_output_invalid_json_raises(hass):
|
||||
entity = DummyTaskEntity(
|
||||
hass,
|
||||
DummyClient(TextGenerationResult(response="not-json")),
|
||||
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
|
||||
)
|
||||
chat_log = DummyChatLog()
|
||||
task = DummyGenTask(structure=vol.Schema({"foo": int}))
|
||||
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_extraction_success(hass):
|
||||
tool_call = llm.ToolInput("submit_response", {"value": 42})
|
||||
entity = DummyTaskEntity(
|
||||
hass,
|
||||
DummyClient(TextGenerationResult(response="", tool_calls=[tool_call])),
|
||||
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
|
||||
)
|
||||
chat_log = DummyChatLog()
|
||||
task = DummyGenTask(structure=vol.Schema({"value": int}))
|
||||
|
||||
result = await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
|
||||
|
||||
assert result.data == {"value": 42}
|
||||
assert chat_log.llm_api is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_extraction_missing_tool_args_raises(hass):
|
||||
class DummyToolCall:
|
||||
def __init__(self, tool_args=None):
|
||||
self.tool_args = tool_args
|
||||
|
||||
tool_call = DummyToolCall(tool_args=None)
|
||||
entity = DummyTaskEntity(
|
||||
hass,
|
||||
DummyClient(TextGenerationResult(response="", tool_calls=cast(Any, [tool_call]))),
|
||||
{CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
|
||||
)
|
||||
chat_log = DummyChatLog()
|
||||
task = DummyGenTask(structure=vol.Schema({"value": int}))
|
||||
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
|
||||
|
||||
105
tests/llama_conversation/test_basic.py
Normal file
105
tests/llama_conversation/test_basic.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Lightweight smoke tests for backend helpers.
|
||||
|
||||
These avoid backend calls and only cover helper utilities to keep the suite green
|
||||
while the integration evolves. No integration code is modified.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
|
||||
|
||||
from custom_components.llama_conversation.backends.llamacpp import snapshot_settings
|
||||
from custom_components.llama_conversation.backends.ollama import OllamaAPIClient, _normalize_path
|
||||
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIClient
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_LLAMACPP_BATCH_SIZE,
|
||||
CONF_LLAMACPP_BATCH_THREAD_COUNT,
|
||||
CONF_LLAMACPP_THREAD_COUNT,
|
||||
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
||||
CONF_GBNF_GRAMMAR_FILE,
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
DEFAULT_CONTEXT_LENGTH,
|
||||
DEFAULT_LLAMACPP_BATCH_SIZE,
|
||||
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
|
||||
DEFAULT_LLAMACPP_THREAD_COUNT,
|
||||
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
||||
DEFAULT_GBNF_GRAMMAR_FILE,
|
||||
DEFAULT_PROMPT_CACHING_ENABLED,
|
||||
CONF_GENERIC_OPENAI_PATH,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def hass_defaults(hass):
|
||||
return hass
|
||||
|
||||
|
||||
def test_snapshot_settings_defaults():
|
||||
options = {CONF_CHAT_MODEL: "test-model"}
|
||||
snap = snapshot_settings(options)
|
||||
assert snap[CONF_CONTEXT_LENGTH] == DEFAULT_CONTEXT_LENGTH
|
||||
assert snap[CONF_LLAMACPP_BATCH_SIZE] == DEFAULT_LLAMACPP_BATCH_SIZE
|
||||
assert snap[CONF_LLAMACPP_THREAD_COUNT] == DEFAULT_LLAMACPP_THREAD_COUNT
|
||||
assert snap[CONF_LLAMACPP_BATCH_THREAD_COUNT] == DEFAULT_LLAMACPP_BATCH_THREAD_COUNT
|
||||
assert snap[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] == DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION
|
||||
assert snap[CONF_GBNF_GRAMMAR_FILE] == DEFAULT_GBNF_GRAMMAR_FILE
|
||||
assert snap[CONF_PROMPT_CACHING_ENABLED] == DEFAULT_PROMPT_CACHING_ENABLED
|
||||
|
||||
|
||||
def test_snapshot_settings_overrides():
|
||||
options = {
|
||||
CONF_CONTEXT_LENGTH: 4096,
|
||||
CONF_LLAMACPP_BATCH_SIZE: 64,
|
||||
CONF_LLAMACPP_THREAD_COUNT: 6,
|
||||
CONF_LLAMACPP_BATCH_THREAD_COUNT: 3,
|
||||
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION: True,
|
||||
CONF_GBNF_GRAMMAR_FILE: "custom.gbnf",
|
||||
CONF_PROMPT_CACHING_ENABLED: True,
|
||||
}
|
||||
snap = snapshot_settings(options)
|
||||
assert snap[CONF_CONTEXT_LENGTH] == 4096
|
||||
assert snap[CONF_LLAMACPP_BATCH_SIZE] == 64
|
||||
assert snap[CONF_LLAMACPP_THREAD_COUNT] == 6
|
||||
assert snap[CONF_LLAMACPP_BATCH_THREAD_COUNT] == 3
|
||||
assert snap[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] is True
|
||||
assert snap[CONF_GBNF_GRAMMAR_FILE] == "custom.gbnf"
|
||||
assert snap[CONF_PROMPT_CACHING_ENABLED] is True
|
||||
|
||||
|
||||
def test_ollama_keep_alive_formatting():
|
||||
assert OllamaAPIClient._format_keep_alive("0") == 0
|
||||
assert OllamaAPIClient._format_keep_alive("0.0") == 0
|
||||
assert OllamaAPIClient._format_keep_alive(5) == "5m"
|
||||
assert OllamaAPIClient._format_keep_alive("15") == "15m"
|
||||
|
||||
|
||||
def test_generic_openai_name_and_path(hass_defaults):
|
||||
client = GenericOpenAIAPIClient(
|
||||
hass_defaults,
|
||||
{
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "8080",
|
||||
CONF_SSL: False,
|
||||
CONF_GENERIC_OPENAI_PATH: "v1",
|
||||
CONF_CHAT_MODEL: "demo",
|
||||
},
|
||||
)
|
||||
name = client.get_name(
|
||||
{
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "8080",
|
||||
CONF_SSL: False,
|
||||
CONF_GENERIC_OPENAI_PATH: "v1",
|
||||
}
|
||||
)
|
||||
assert "Generic OpenAI" in name
|
||||
assert "localhost" in name
|
||||
|
||||
|
||||
def test_normalize_path_helper():
|
||||
assert _normalize_path(None) == ""
|
||||
assert _normalize_path("") == ""
|
||||
assert _normalize_path("/v1/") == "/v1"
|
||||
assert _normalize_path("v2") == "/v2"
|
||||
@@ -1,350 +1,204 @@
|
||||
"""Config flow option schema tests to ensure options are wired per-backend."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from homeassistant import config_entries, setup
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.const import (
|
||||
CONF_HOST,
|
||||
CONF_PORT,
|
||||
CONF_SSL,
|
||||
CONF_LLM_HASS_API,
|
||||
)
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema, ConfigFlow
|
||||
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_PROMPT,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_MIN_P,
|
||||
CONF_TYPICAL_P,
|
||||
CONF_REQUEST_TIMEOUT,
|
||||
CONF_BACKEND_TYPE,
|
||||
CONF_DOWNLOADED_MODEL_FILE,
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_PROMPT_TEMPLATE,
|
||||
CONF_TOOL_FORMAT,
|
||||
CONF_TOOL_MULTI_TURN_CHAT,
|
||||
CONF_ENABLE_FLASH_ATTENTION,
|
||||
CONF_USE_GBNF_GRAMMAR,
|
||||
CONF_GBNF_GRAMMAR_FILE,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
||||
CONF_TEXT_GEN_WEBUI_PRESET,
|
||||
CONF_OPENAI_API_KEY,
|
||||
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
|
||||
CONF_REFRESH_SYSTEM_PROMPT,
|
||||
CONF_REMEMBER_CONVERSATION,
|
||||
CONF_REMEMBER_NUM_INTERACTIONS,
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
CONF_PROMPT_CACHING_INTERVAL,
|
||||
CONF_SERVICE_CALL_REGEX,
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT,
|
||||
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
CONF_OLLAMA_KEEP_ALIVE_MIN,
|
||||
CONF_OLLAMA_JSON_MODE,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_BATCH_SIZE,
|
||||
CONF_THREAD_COUNT,
|
||||
CONF_BATCH_THREAD_COUNT,
|
||||
BACKEND_TYPE_LLAMA_HF,
|
||||
BACKEND_TYPE_LLAMA_EXISTING,
|
||||
BACKEND_TYPE_LLAMA_CPP,
|
||||
BACKEND_TYPE_TEXT_GEN_WEBUI,
|
||||
BACKEND_TYPE_GENERIC_OPENAI,
|
||||
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
|
||||
BACKEND_TYPE_LLAMA_CPP_SERVER,
|
||||
BACKEND_TYPE_OLLAMA,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_K,
|
||||
DEFAULT_TOP_P,
|
||||
DEFAULT_MIN_P,
|
||||
DEFAULT_TYPICAL_P,
|
||||
DEFAULT_BACKEND_TYPE,
|
||||
DEFAULT_REQUEST_TIMEOUT,
|
||||
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
DEFAULT_PROMPT_TEMPLATE,
|
||||
DEFAULT_ENABLE_FLASH_ATTENTION,
|
||||
DEFAULT_USE_GBNF_GRAMMAR,
|
||||
DEFAULT_GBNF_GRAMMAR_FILE,
|
||||
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_GBNF_GRAMMAR_FILE,
|
||||
CONF_LLAMACPP_BATCH_SIZE,
|
||||
CONF_LLAMACPP_BATCH_THREAD_COUNT,
|
||||
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
||||
CONF_LLAMACPP_THREAD_COUNT,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_MIN_P,
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
||||
CONF_OLLAMA_JSON_MODE,
|
||||
CONF_OLLAMA_KEEP_ALIVE_MIN,
|
||||
CONF_PROMPT,
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
CONF_PROMPT_CACHING_INTERVAL,
|
||||
CONF_REQUEST_TIMEOUT,
|
||||
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
CONF_TEXT_GEN_WEBUI_PRESET,
|
||||
CONF_THINKING_PREFIX,
|
||||
CONF_TOOL_CALL_PREFIX,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_TYPICAL_P,
|
||||
CONF_TEMPERATURE,
|
||||
DEFAULT_CONTEXT_LENGTH,
|
||||
DEFAULT_LLAMACPP_BATCH_SIZE,
|
||||
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
|
||||
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
||||
DEFAULT_LLAMACPP_THREAD_COUNT,
|
||||
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
||||
DEFAULT_REFRESH_SYSTEM_PROMPT,
|
||||
DEFAULT_REMEMBER_CONVERSATION,
|
||||
DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
||||
DEFAULT_PROMPT_CACHING_ENABLED,
|
||||
DEFAULT_PROMPT_CACHING_INTERVAL,
|
||||
DEFAULT_SERVICE_CALL_REGEX,
|
||||
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
|
||||
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
|
||||
DEFAULT_OLLAMA_JSON_MODE,
|
||||
DEFAULT_CONTEXT_LENGTH,
|
||||
DEFAULT_BATCH_SIZE,
|
||||
DEFAULT_THREAD_COUNT,
|
||||
DEFAULT_BATCH_THREAD_COUNT,
|
||||
DOMAIN,
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_CACHING_INTERVAL,
|
||||
DEFAULT_REQUEST_TIMEOUT,
|
||||
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
DEFAULT_THINKING_PREFIX,
|
||||
DEFAULT_TOOL_CALL_PREFIX,
|
||||
DEFAULT_TOP_K,
|
||||
DEFAULT_TOP_P,
|
||||
DEFAULT_TYPICAL_P,
|
||||
)
|
||||
|
||||
# async def test_validate_config_flow_llama_hf(hass: HomeAssistant):
|
||||
# result = await hass.config_entries.flow.async_init(
|
||||
# DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
# )
|
||||
# assert result["type"] == FlowResultType.FORM
|
||||
# assert result["errors"] is None
|
||||
|
||||
# result2 = await hass.config_entries.flow.async_configure(
|
||||
# result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_HF },
|
||||
# )
|
||||
# assert result2["type"] == FlowResultType.FORM
|
||||
|
||||
# with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry:
|
||||
# result3 = await hass.config_entries.flow.async_configure(
|
||||
# result2["flow_id"],
|
||||
# TEST_DATA,
|
||||
# )
|
||||
# await hass.async_block_till_done()
|
||||
|
||||
# assert result3["type"] == "create_entry"
|
||||
# assert result3["title"] == ""
|
||||
# assert result3["data"] == {
|
||||
# # ACCOUNT_ID: TEST_DATA["account_id"],
|
||||
# # CONF_PASSWORD: TEST_DATA["password"],
|
||||
# # CONNECTION_TYPE: CLOUD,
|
||||
# }
|
||||
# assert result3["options"] == {}
|
||||
# assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
@pytest.fixture
|
||||
def validate_connections_mock():
|
||||
validate_mock = MagicMock()
|
||||
with patch.object(ConfigFlow, '_validate_text_generation_webui', new=validate_mock), \
|
||||
patch.object(ConfigFlow, '_validate_ollama', new=validate_mock):
|
||||
yield validate_mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_setup_entry():
|
||||
with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry, \
|
||||
patch("custom_components.llama_conversation.async_unload_entry", return_value=True):
|
||||
yield mock_setup_entry
|
||||
|
||||
async def test_validate_config_flow_generic_openai(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["errors"] == {}
|
||||
assert result["step_id"] == "pick_backend"
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI },
|
||||
def _schema(hass: HomeAssistant, backend: str, options: dict | None = None):
|
||||
return local_llama_config_option_schema(
|
||||
hass=hass,
|
||||
language="en",
|
||||
options=options or {},
|
||||
backend_type=backend,
|
||||
subentry_type="conversation",
|
||||
)
|
||||
|
||||
assert result2["type"] == FlowResultType.FORM
|
||||
assert result2["errors"] == {}
|
||||
assert result2["step_id"] == "remote_model"
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"],
|
||||
{
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5000",
|
||||
CONF_SSL: False,
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
},
|
||||
)
|
||||
def _get_default(schema: dict, key_name: str):
|
||||
for key in schema:
|
||||
if getattr(key, "schema", None) == key_name:
|
||||
default = getattr(key, "default", None)
|
||||
return default() if callable(default) else default
|
||||
raise AssertionError(f"Key {key_name} not found in schema")
|
||||
|
||||
assert result3["type"] == FlowResultType.FORM
|
||||
assert result3["errors"] == {}
|
||||
assert result3["step_id"] == "model_parameters"
|
||||
|
||||
options_dict = {
|
||||
CONF_PROMPT: DEFAULT_PROMPT,
|
||||
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
|
||||
CONF_TOP_P: DEFAULT_TOP_P,
|
||||
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
|
||||
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
|
||||
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
|
||||
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
|
||||
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
||||
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
||||
def _get_suggested(schema: dict, key_name: str):
|
||||
for key in schema:
|
||||
if getattr(key, "schema", None) == key_name:
|
||||
return (getattr(key, "description", {}) or {}).get("suggested_value")
|
||||
raise AssertionError(f"Key {key_name} not found in schema")
|
||||
|
||||
|
||||
def test_schema_llama_cpp_defaults_and_overrides(hass: HomeAssistant):
|
||||
overrides = {
|
||||
CONF_CONTEXT_LENGTH: 4096,
|
||||
CONF_LLAMACPP_BATCH_SIZE: 8,
|
||||
CONF_LLAMACPP_THREAD_COUNT: 6,
|
||||
CONF_LLAMACPP_BATCH_THREAD_COUNT: 3,
|
||||
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION: True,
|
||||
CONF_PROMPT_CACHING_INTERVAL: 15,
|
||||
CONF_TOP_K: 12,
|
||||
CONF_TOOL_CALL_PREFIX: "<tc>",
|
||||
}
|
||||
|
||||
result4 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"], options_dict
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP, overrides)
|
||||
|
||||
assert result4["type"] == "create_entry"
|
||||
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
|
||||
assert result4["data"] == {
|
||||
CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI,
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5000",
|
||||
CONF_SSL: False,
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
expected_keys = {
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_MIN_P,
|
||||
CONF_TYPICAL_P,
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
CONF_PROMPT_CACHING_INTERVAL,
|
||||
CONF_GBNF_GRAMMAR_FILE,
|
||||
CONF_LLAMACPP_BATCH_SIZE,
|
||||
CONF_LLAMACPP_THREAD_COUNT,
|
||||
CONF_LLAMACPP_BATCH_THREAD_COUNT,
|
||||
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
||||
}
|
||||
assert result4["options"] == options_dict
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
assert expected_keys.issubset({getattr(k, "schema", None) for k in schema})
|
||||
|
||||
async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations, validate_connections_mock):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["errors"] == {}
|
||||
assert result["step_id"] == "pick_backend"
|
||||
assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
|
||||
assert _get_default(schema, CONF_LLAMACPP_BATCH_SIZE) == DEFAULT_LLAMACPP_BATCH_SIZE
|
||||
assert _get_default(schema, CONF_LLAMACPP_THREAD_COUNT) == DEFAULT_LLAMACPP_THREAD_COUNT
|
||||
assert _get_default(schema, CONF_LLAMACPP_BATCH_THREAD_COUNT) == DEFAULT_LLAMACPP_BATCH_THREAD_COUNT
|
||||
assert _get_default(schema, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION) is DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION
|
||||
assert _get_default(schema, CONF_PROMPT_CACHING_INTERVAL) == DEFAULT_PROMPT_CACHING_INTERVAL
|
||||
# suggested values should reflect overrides
|
||||
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 4096
|
||||
assert _get_suggested(schema, CONF_LLAMACPP_BATCH_SIZE) == 8
|
||||
assert _get_suggested(schema, CONF_LLAMACPP_THREAD_COUNT) == 6
|
||||
assert _get_suggested(schema, CONF_LLAMACPP_BATCH_THREAD_COUNT) == 3
|
||||
assert _get_suggested(schema, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION) is True
|
||||
assert _get_suggested(schema, CONF_PROMPT_CACHING_INTERVAL) == 15
|
||||
assert _get_suggested(schema, CONF_TOP_K) == 12
|
||||
assert _get_suggested(schema, CONF_TOOL_CALL_PREFIX) == "<tc>"
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA },
|
||||
)
|
||||
|
||||
assert result2["type"] == FlowResultType.FORM
|
||||
assert result2["errors"] == {}
|
||||
assert result2["step_id"] == "remote_model"
|
||||
|
||||
# simulate incorrect settings on first try
|
||||
validate_connections_mock.side_effect = [
|
||||
("failed_to_connect", Exception("ConnectionError"), []),
|
||||
(None, None, [])
|
||||
]
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"],
|
||||
{
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5000",
|
||||
CONF_SSL: False,
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
},
|
||||
)
|
||||
|
||||
assert result3["type"] == FlowResultType.FORM
|
||||
assert len(result3["errors"]) == 1
|
||||
assert "base" in result3["errors"]
|
||||
assert result3["step_id"] == "remote_model"
|
||||
|
||||
# retry
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"],
|
||||
{
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5001",
|
||||
CONF_SSL: False,
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
},
|
||||
)
|
||||
|
||||
assert result3["type"] == FlowResultType.FORM
|
||||
assert result3["errors"] == {}
|
||||
assert result3["step_id"] == "model_parameters"
|
||||
|
||||
options_dict = {
|
||||
CONF_PROMPT: DEFAULT_PROMPT,
|
||||
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
|
||||
CONF_TOP_P: DEFAULT_TOP_P,
|
||||
CONF_TOP_K: DEFAULT_TOP_K,
|
||||
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
|
||||
CONF_TYPICAL_P: DEFAULT_MIN_P,
|
||||
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
|
||||
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
|
||||
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
|
||||
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
||||
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
||||
CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
|
||||
CONF_OLLAMA_KEEP_ALIVE_MIN: DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
|
||||
CONF_OLLAMA_JSON_MODE: DEFAULT_OLLAMA_JSON_MODE,
|
||||
def test_schema_text_gen_webui_options_preserved(hass: HomeAssistant):
|
||||
overrides = {
|
||||
CONF_REQUEST_TIMEOUT: 123,
|
||||
CONF_TEXT_GEN_WEBUI_PRESET: "custom-preset",
|
||||
CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
CONF_CONTEXT_LENGTH: 2048,
|
||||
}
|
||||
|
||||
result4 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"], options_dict
|
||||
schema = _schema(hass, BACKEND_TYPE_TEXT_GEN_WEBUI, overrides)
|
||||
|
||||
expected = {CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, CONF_REQUEST_TIMEOUT, CONF_CONTEXT_LENGTH}
|
||||
assert expected.issubset({getattr(k, "schema", None) for k in schema})
|
||||
assert _get_default(schema, CONF_REQUEST_TIMEOUT) == DEFAULT_REQUEST_TIMEOUT
|
||||
assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
|
||||
assert _get_suggested(schema, CONF_REQUEST_TIMEOUT) == 123
|
||||
assert _get_suggested(schema, CONF_TEXT_GEN_WEBUI_PRESET) == "custom-preset"
|
||||
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 2048
|
||||
|
||||
|
||||
def test_schema_generic_openai_options_preserved(hass: HomeAssistant):
|
||||
overrides = {CONF_TOP_P: 0.25, CONF_REQUEST_TIMEOUT: 321}
|
||||
|
||||
schema = _schema(hass, BACKEND_TYPE_GENERIC_OPENAI, overrides)
|
||||
|
||||
assert {CONF_TOP_P, CONF_REQUEST_TIMEOUT}.issubset({getattr(k, "schema", None) for k in schema})
|
||||
assert _get_default(schema, CONF_TOP_P) == DEFAULT_TOP_P
|
||||
assert _get_default(schema, CONF_REQUEST_TIMEOUT) == DEFAULT_REQUEST_TIMEOUT
|
||||
assert _get_suggested(schema, CONF_TOP_P) == 0.25
|
||||
assert _get_suggested(schema, CONF_REQUEST_TIMEOUT) == 321
|
||||
# Base prompt options still present
|
||||
prompt_default = _get_default(schema, CONF_PROMPT)
|
||||
assert prompt_default is not None and "You are 'Al'" in prompt_default
|
||||
assert _get_default(schema, CONF_NUM_IN_CONTEXT_EXAMPLES) == DEFAULT_NUM_IN_CONTEXT_EXAMPLES
|
||||
|
||||
|
||||
def test_schema_llama_cpp_server_includes_gbnf(hass: HomeAssistant):
|
||||
schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP_SERVER)
|
||||
keys = {getattr(k, "schema", None) for k in schema}
|
||||
|
||||
assert {CONF_MAX_TOKENS, CONF_TOP_K, CONF_GBNF_GRAMMAR_FILE}.issubset(keys)
|
||||
assert _get_default(schema, CONF_GBNF_GRAMMAR_FILE) == "output.gbnf"
|
||||
|
||||
|
||||
def test_schema_ollama_defaults_and_overrides(hass: HomeAssistant):
|
||||
overrides = {CONF_OLLAMA_KEEP_ALIVE_MIN: 5, CONF_CONTEXT_LENGTH: 1024, CONF_TOP_K: 7}
|
||||
schema = _schema(hass, BACKEND_TYPE_OLLAMA, overrides)
|
||||
|
||||
assert {CONF_MAX_TOKENS, CONF_CONTEXT_LENGTH, CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE}.issubset(
|
||||
{getattr(k, "schema", None) for k in schema}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert _get_default(schema, CONF_OLLAMA_KEEP_ALIVE_MIN) == DEFAULT_OLLAMA_KEEP_ALIVE_MIN
|
||||
assert _get_default(schema, CONF_OLLAMA_JSON_MODE) is DEFAULT_OLLAMA_JSON_MODE
|
||||
assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
|
||||
assert _get_default(schema, CONF_TOP_K) == DEFAULT_TOP_K
|
||||
assert _get_suggested(schema, CONF_OLLAMA_KEEP_ALIVE_MIN) == 5
|
||||
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 1024
|
||||
assert _get_suggested(schema, CONF_TOP_K) == 7
|
||||
|
||||
assert result4["type"] == "create_entry"
|
||||
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
|
||||
assert result4["data"] == {
|
||||
CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA,
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5001",
|
||||
CONF_SSL: False,
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
}
|
||||
assert result4["options"] == options_dict
|
||||
mock_setup_entry.assert_called_once()
|
||||
|
||||
# TODO: write tests for configflow setup for llama.cpp (both versions) + text-generation-webui
|
||||
def test_schema_includes_llm_api_selector(monkeypatch, hass: HomeAssistant):
|
||||
monkeypatch.setattr(
|
||||
"custom_components.llama_conversation.config_flow.llm.async_get_apis",
|
||||
lambda _hass: [type("API", (), {"id": "dummy", "name": "Dummy API", "tools": []})()],
|
||||
)
|
||||
schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP)
|
||||
|
||||
def test_validate_options_schema(hass: HomeAssistant):
|
||||
|
||||
universal_options = [
|
||||
CONF_LLM_HASS_API, CONF_PROMPT, CONF_PROMPT_TEMPLATE, CONF_TOOL_FORMAT, CONF_TOOL_MULTI_TURN_CHAT,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES,
|
||||
CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_SERVICE_CALL_REGEX, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS,
|
||||
]
|
||||
|
||||
options_llama_hf = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_HF)
|
||||
assert set(options_llama_hf.keys()) == set(universal_options + [
|
||||
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
|
||||
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
|
||||
CONF_CONTEXT_LENGTH, # supports context length
|
||||
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
|
||||
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
|
||||
])
|
||||
|
||||
options_llama_existing = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_EXISTING)
|
||||
assert set(options_llama_existing.keys()) == set(universal_options + [
|
||||
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
|
||||
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
|
||||
CONF_CONTEXT_LENGTH, # supports context length
|
||||
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
|
||||
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
|
||||
])
|
||||
|
||||
options_ollama = local_llama_config_option_schema(hass, None, BACKEND_TYPE_OLLAMA)
|
||||
assert set(options_ollama.keys()) == set(universal_options + [
|
||||
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_TYPICAL_P, # supports top_k temperature, top_p and typical_p samplers
|
||||
CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE, # ollama specific
|
||||
CONF_CONTEXT_LENGTH, # supports context length
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
||||
])
|
||||
|
||||
options_text_gen_webui = local_llama_config_option_schema(hass, None, BACKEND_TYPE_TEXT_GEN_WEBUI)
|
||||
assert set(options_text_gen_webui.keys()) == set(universal_options + [
|
||||
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
|
||||
CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, # text-gen-webui specific
|
||||
CONF_CONTEXT_LENGTH, # supports context length
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
||||
])
|
||||
|
||||
options_generic_openai = local_llama_config_option_schema(hass, None, BACKEND_TYPE_GENERIC_OPENAI)
|
||||
assert set(options_generic_openai.keys()) == set(universal_options + [
|
||||
CONF_TEMPERATURE, CONF_TOP_P, # only supports top_p and temperature sampling
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
||||
])
|
||||
|
||||
options_llama_cpp_python_server = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER)
|
||||
assert set(options_llama_cpp_python_server.keys()) == set(universal_options + [
|
||||
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports top_k, temperature, and top p sampling
|
||||
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
||||
])
|
||||
assert _get_default(schema, CONF_LLM_HASS_API) is None
|
||||
# Base prompt and thinking prefixes use defaults when not overridden
|
||||
prompt_default = _get_default(schema, CONF_PROMPT)
|
||||
assert prompt_default is not None and "You are 'Al'" in prompt_default
|
||||
assert _get_default(schema, CONF_THINKING_PREFIX) == DEFAULT_THINKING_PREFIX
|
||||
assert _get_default(schema, CONF_TOOL_CALL_PREFIX) == DEFAULT_TOOL_CALL_PREFIX
|
||||
|
||||
114
tests/llama_conversation/test_conversation_agent.py
Normal file
114
tests/llama_conversation/test_conversation_agent.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Tests for LocalLLMAgent async_process."""
|
||||
|
||||
import pytest
|
||||
from contextlib import contextmanager
|
||||
|
||||
from homeassistant.components.conversation import ConversationInput, SystemContent, AssistantContent
|
||||
from homeassistant.const import MATCH_ALL
|
||||
|
||||
from custom_components.llama_conversation.conversation import LocalLLMAgent
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_PROMPT,
|
||||
DEFAULT_PROMPT,
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
|
||||
class DummyClient:
|
||||
def __init__(self, hass):
|
||||
self.hass = hass
|
||||
self.generated_prompts = []
|
||||
|
||||
def _generate_system_prompt(self, prompt_template, llm_api, entity_options):
|
||||
self.generated_prompts.append(prompt_template)
|
||||
return "rendered-system-prompt"
|
||||
|
||||
async def _async_generate(self, conv, agent_id, chat_log, entity_options):
|
||||
async def gen():
|
||||
yield AssistantContent(agent_id=agent_id, content="hello from llm")
|
||||
return gen()
|
||||
|
||||
|
||||
class DummySubentry:
|
||||
def __init__(self, subentry_id="sub1", title="Test Agent", chat_model="model"):
|
||||
self.subentry_id = subentry_id
|
||||
self.title = title
|
||||
self.subentry_type = DOMAIN
|
||||
self.data = {CONF_CHAT_MODEL: chat_model}
|
||||
|
||||
|
||||
class DummyEntry:
|
||||
def __init__(self, entry_id="entry1", options=None, subentry=None, runtime_data=None):
|
||||
self.entry_id = entry_id
|
||||
self.options = options or {}
|
||||
self.subentries = {subentry.subentry_id: subentry}
|
||||
self.runtime_data = runtime_data
|
||||
|
||||
def add_update_listener(self, _cb):
|
||||
return lambda: None
|
||||
|
||||
|
||||
class FakeChatLog:
|
||||
def __init__(self):
|
||||
self.content = []
|
||||
self.llm_api = None
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
class FakeChatSession:
|
||||
def __enter__(self):
|
||||
return {}
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_process_generates_response(monkeypatch, hass):
|
||||
client = DummyClient(hass)
|
||||
subentry = DummySubentry()
|
||||
entry = DummyEntry(subentry=subentry, runtime_data=client)
|
||||
|
||||
# Make entry discoverable through hass data as LocalLLMEntity expects.
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry
|
||||
|
||||
@contextmanager
|
||||
def fake_chat_session(_hass, _conversation_id):
|
||||
yield FakeChatSession()
|
||||
|
||||
@contextmanager
|
||||
def fake_chat_log(_hass, _session, _user_input):
|
||||
yield FakeChatLog()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"custom_components.llama_conversation.conversation.chat_session.async_get_chat_session",
|
||||
fake_chat_session,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"custom_components.llama_conversation.conversation.conversation.async_get_chat_log",
|
||||
fake_chat_log,
|
||||
)
|
||||
|
||||
agent = LocalLLMAgent(hass, entry, subentry, client)
|
||||
|
||||
result = await agent.async_process(
|
||||
ConversationInput(
|
||||
text="turn on the lights",
|
||||
context=None,
|
||||
conversation_id="conv-id",
|
||||
device_id=None,
|
||||
language="en",
|
||||
agent_id="agent-1",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.response.speech["plain"]["speech"] == "hello from llm"
|
||||
# System prompt should be rendered once when message history is empty.
|
||||
assert client.generated_prompts == [DEFAULT_PROMPT]
|
||||
assert agent.supported_languages == MATCH_ALL
|
||||
162
tests/llama_conversation/test_entity.py
Normal file
162
tests/llama_conversation/test_entity.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Tests for LocalLLMClient helpers in entity.py."""
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import pytest
|
||||
from json import JSONDecodeError
|
||||
|
||||
from custom_components.llama_conversation.entity import LocalLLMClient
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
DEFAULT_TOOL_CALL_PREFIX,
|
||||
DEFAULT_TOOL_CALL_SUFFIX,
|
||||
DEFAULT_THINKING_PREFIX,
|
||||
DEFAULT_THINKING_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
class DummyLocalClient(LocalLLMClient):
|
||||
@staticmethod
|
||||
def get_name(_client_options):
|
||||
return "dummy"
|
||||
|
||||
|
||||
class DummyLLMApi:
|
||||
def __init__(self):
|
||||
self.tools = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(hass):
|
||||
# Disable ICL loading during tests to avoid filesystem access.
|
||||
return DummyLocalClient(hass, {CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_parse_completion_parses_tool_call(client):
|
||||
raw_tool = '{"name":"light.turn_on","arguments":{"brightness":0.5,"to_say":" acknowledged"}}'
|
||||
completion = (
|
||||
f"{DEFAULT_THINKING_PREFIX}internal{DEFAULT_THINKING_SUFFIX}"
|
||||
f"hello {DEFAULT_TOOL_CALL_PREFIX}{raw_tool}{DEFAULT_TOOL_CALL_SUFFIX}"
|
||||
)
|
||||
|
||||
result = await client._async_parse_completion(DummyLLMApi(), "agent-id", {}, completion)
|
||||
|
||||
assert result.response.strip().startswith("hello")
|
||||
assert "acknowledged" in result.response
|
||||
assert result.tool_calls
|
||||
tool_call = result.tool_calls[0]
|
||||
assert tool_call.tool_name == "light.turn_on"
|
||||
assert tool_call.tool_args["brightness"] == 127
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_parse_completion_ignores_tools_without_llm_api(client):
|
||||
raw_tool = '{"name":"light.turn_on","arguments":{"brightness":1}}'
|
||||
completion = f"hello {DEFAULT_TOOL_CALL_PREFIX}{raw_tool}{DEFAULT_TOOL_CALL_SUFFIX}"
|
||||
|
||||
result = await client._async_parse_completion(None, "agent-id", {}, completion)
|
||||
|
||||
assert result.tool_calls == []
|
||||
assert result.response.strip() == "hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_parse_completion_malformed_tool_raises(client):
|
||||
bad_tool = f"{DEFAULT_TOOL_CALL_PREFIX}{{not-json{DEFAULT_TOOL_CALL_SUFFIX}"
|
||||
|
||||
with pytest.raises(JSONDecodeError):
|
||||
await client._async_parse_completion(DummyLLMApi(), "agent-id", {}, bad_tool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_parse_completion_handles_streamed_tool_call(client):
|
||||
async def token_generator():
|
||||
yield ("Hi", None)
|
||||
yield (
|
||||
None,
|
||||
[
|
||||
{
|
||||
"function": {
|
||||
"name": "light.turn_on",
|
||||
"arguments": {"brightness": 0.25, "to_say": " ok"},
|
||||
}
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
stream = client._async_stream_parse_completion(
|
||||
DummyLLMApi(), "agent-id", {}, anext_token=token_generator()
|
||||
)
|
||||
|
||||
results = [chunk async for chunk in stream]
|
||||
|
||||
assert results[0].response == "Hi"
|
||||
assert results[1].response.strip() == "ok"
|
||||
assert results[1].tool_calls[0].tool_args["brightness"] == 63
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_parse_completion_malformed_tool_raises(client):
|
||||
async def token_generator():
|
||||
yield ("Hi", None)
|
||||
yield (None, ["{not-json"])
|
||||
|
||||
with pytest.raises(JSONDecodeError):
|
||||
async for _chunk in client._async_stream_parse_completion(
|
||||
DummyLLMApi(), "agent-id", {}, anext_token=token_generator()
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_parse_completion_ignores_tools_without_llm_api(client):
|
||||
async def token_generator():
|
||||
yield ("Hi", None)
|
||||
yield (None, ["{}"])
|
||||
|
||||
results = [chunk async for chunk in client._async_stream_parse_completion(
|
||||
None, "agent-id", {}, anext_token=token_generator()
|
||||
)]
|
||||
|
||||
assert results[0].response == "Hi"
|
||||
assert results[1].tool_calls is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_get_exposed_entities_respects_exposure(monkeypatch, client, hass):
|
||||
hass.states.async_set("light.exposed", "on", {"friendly_name": "Lamp"})
|
||||
hass.states.async_set("switch.hidden", "off", {"friendly_name": "Hidden"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"custom_components.llama_conversation.entity.async_should_expose",
|
||||
lambda _hass, _domain, entity_id: not entity_id.endswith("hidden"),
|
||||
)
|
||||
|
||||
exposed = client._async_get_exposed_entities()
|
||||
|
||||
assert "light.exposed" in exposed
|
||||
assert "switch.hidden" not in exposed
|
||||
assert exposed["light.exposed"]["friendly_name"] == "Lamp"
|
||||
assert exposed["light.exposed"]["state"] == "on"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_system_prompt_renders(monkeypatch, client, hass):
|
||||
hass.states.async_set("light.kitchen", "on", {"friendly_name": "Kitchen"})
|
||||
monkeypatch.setattr(
|
||||
"custom_components.llama_conversation.entity.async_should_expose",
|
||||
lambda _hass, _domain, _entity_id: True,
|
||||
)
|
||||
|
||||
rendered = client._generate_system_prompt(
|
||||
"Devices:\n{{ formatted_devices }}",
|
||||
llm_api=None,
|
||||
entity_options={CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: []},
|
||||
)
|
||||
if inspect.iscoroutine(rendered):
|
||||
rendered = await rendered
|
||||
|
||||
assert isinstance(rendered, str)
|
||||
assert "light.kitchen" in rendered
|
||||
159
tests/llama_conversation/test_migrations.py
Normal file
159
tests/llama_conversation/test_migrations.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Regression tests for config entry migration in __init__.py."""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_HOST, CONF_PORT, CONF_SSL
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from pytest_homeassistant_custom_component.common import MockConfigEntry
|
||||
|
||||
from custom_components.llama_conversation import async_migrate_entry
|
||||
from custom_components.llama_conversation.const import (
|
||||
BACKEND_TYPE_LLAMA_CPP,
|
||||
BACKEND_TYPE_GENERIC_OPENAI,
|
||||
BACKEND_TYPE_LLAMA_CPP_SERVER,
|
||||
CONF_BACKEND_TYPE,
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_DOWNLOADED_MODEL_FILE,
|
||||
CONF_DOWNLOADED_MODEL_QUANTIZATION,
|
||||
CONF_GENERIC_OPENAI_PATH,
|
||||
CONF_PROMPT,
|
||||
CONF_REQUEST_TIMEOUT,
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_v1_is_rejected(hass):
|
||||
entry = MockConfigEntry(domain=DOMAIN, data={CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_CPP}, version=1)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
result = await async_migrate_entry(hass, entry)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_v2_creates_subentry_and_updates_entry(monkeypatch, hass):
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
title="llama 'Test Agent' entry",
|
||||
data={CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI},
|
||||
options={
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "8080",
|
||||
CONF_SSL: False,
|
||||
CONF_GENERIC_OPENAI_PATH: "v1",
|
||||
CONF_PROMPT: "hello",
|
||||
CONF_REQUEST_TIMEOUT: 90,
|
||||
CONF_CHAT_MODEL: "model-x",
|
||||
CONF_CONTEXT_LENGTH: 1024,
|
||||
},
|
||||
version=2,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
added_subentries = []
|
||||
update_calls = []
|
||||
|
||||
def fake_add_subentry(cfg_entry, subentry):
|
||||
added_subentries.append((cfg_entry, subentry))
|
||||
|
||||
def fake_update_entry(cfg_entry, **kwargs):
|
||||
update_calls.append(kwargs)
|
||||
|
||||
monkeypatch.setattr(hass.config_entries, "async_add_subentry", fake_add_subentry)
|
||||
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
|
||||
|
||||
result = await async_migrate_entry(hass, entry)
|
||||
|
||||
assert result is True
|
||||
assert added_subentries, "Subentry should be added"
|
||||
subentry = added_subentries[0][1]
|
||||
assert isinstance(subentry, ConfigSubentry)
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data[CONF_CHAT_MODEL] == "model-x"
|
||||
# Entry should be updated to version 3 with data/options separated
|
||||
assert any(call.get("version") == 3 for call in update_calls)
|
||||
last_options = [c["options"] for c in update_calls if "options" in c][-1]
|
||||
assert last_options[CONF_HOST] == "localhost"
|
||||
assert CONF_PROMPT not in last_options # moved to subentry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_v3_minor0_downloads_model(monkeypatch, hass):
|
||||
sub_data = {
|
||||
CONF_CHAT_MODEL: "model-a",
|
||||
CONF_DOWNLOADED_MODEL_QUANTIZATION: "Q4_K_M",
|
||||
CONF_REQUEST_TIMEOUT: 30,
|
||||
}
|
||||
subentry = ConfigSubentry(data=sub_data, subentry_type="conversation", title="sub", unique_id=None)
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_CPP},
|
||||
options={},
|
||||
version=3,
|
||||
minor_version=0,
|
||||
)
|
||||
entry.subentries = {"sub": subentry}
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
updated_subentries = []
|
||||
update_calls = []
|
||||
|
||||
def fake_update_subentry(cfg_entry, old_sub, *, data=None, **_kwargs):
|
||||
updated_subentries.append((cfg_entry, old_sub, data))
|
||||
|
||||
def fake_update_entry(cfg_entry, **kwargs):
|
||||
update_calls.append(kwargs)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"custom_components.llama_conversation.download_model_from_hf", lambda *_args, **_kw: "file.gguf"
|
||||
)
|
||||
monkeypatch.setattr(hass.config_entries, "async_update_subentry", fake_update_subentry)
|
||||
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
|
||||
|
||||
result = await async_migrate_entry(hass, entry)
|
||||
|
||||
assert result is True
|
||||
assert updated_subentries, "Subentry should be updated with downloaded file"
|
||||
new_data = updated_subentries[0][2]
|
||||
assert new_data[CONF_DOWNLOADED_MODEL_FILE] == "file.gguf"
|
||||
assert any(call.get("minor_version") == 1 for call in update_calls)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_value,expected_list",
|
||||
[("api-1", ["api-1"]), (None, [])],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_v3_minor1_converts_api_to_list(monkeypatch, hass, api_value, expected_list):
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI},
|
||||
options={CONF_LLM_HASS_API: api_value},
|
||||
version=3,
|
||||
minor_version=1,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_update_entry(cfg_entry, **kwargs):
|
||||
calls.append(kwargs)
|
||||
if "options" in kwargs:
|
||||
cfg_entry._options = kwargs["options"] # type: ignore[attr-defined]
|
||||
if "minor_version" in kwargs:
|
||||
cfg_entry._minor_version = kwargs["minor_version"] # type: ignore[attr-defined]
|
||||
|
||||
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
|
||||
|
||||
result = await async_migrate_entry(hass, entry)
|
||||
|
||||
assert result is True
|
||||
options_calls = [c for c in calls if "options" in c]
|
||||
assert options_calls, "async_update_entry should be called with options"
|
||||
assert options_calls[-1]["options"][CONF_LLM_HASS_API] == expected_list
|
||||
|
||||
minor_calls = [c for c in calls if c.get("minor_version")]
|
||||
assert minor_calls and minor_calls[-1]["minor_version"] == 2
|
||||
Reference in New Issue
Block a user