rewrite tests from scratch

Alex O'Connell
2025-12-14 01:07:23 -05:00
parent c8a5b30e5b
commit 6010bdf26c
8 changed files with 714 additions and 1084 deletions

View File

@@ -43,8 +43,6 @@ from .const import (
DEFAULT_TOOL_CALL_PREFIX,
DEFAULT_TOOL_CALL_SUFFIX,
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME,
)
_LOGGER = logging.getLogger(__name__)

View File

@@ -6,4 +6,5 @@ home-assistant-intents
# testing requirements
pytest
pytest-asyncio
pytest-homeassistant-custom-component==0.13.260
# NOTE this must match the version of Home Assistant used for testing
pytest-homeassistant-custom-component==0.13.272
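Because the pinned plugin version tracks the Home Assistant release used for testing, a minimal sketch of a smoke check that surfaces both installed versions in the test output; it assumes only the standard homeassistant.const.__version__ attribute and importlib.metadata, and does not encode the exact pairing rule between the two packages:
import importlib.metadata
from homeassistant.const import __version__ as ha_version
def test_report_installed_versions():
    # Print both versions so a mismatch between the HA core under test and the
    # pytest plugin pin above is visible in the test log.
    plugin_version = importlib.metadata.version("pytest-homeassistant-custom-component")
    print(f"homeassistant={ha_version} pytest-homeassistant-custom-component={plugin_version}")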

View File

@@ -1,763 +0,0 @@
import json
import logging
import pytest
import jinja2
from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY
from custom_components.llama_conversation.backends.llamacpp import LlamaCppAgent
from custom_components.llama_conversation.backends.ollama import OllamaAPIAgent
from custom_components.llama_conversation.backends.tailored_openai import TextGenerationWebuiAgent
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_PROMPT_TEMPLATE,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
CONF_BATCH_SIZE,
CONF_THREAD_COUNT,
CONF_BATCH_THREAD_COUNT,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT_BASE,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_PROMPT_CACHING_ENABLED,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_BATCH_SIZE,
DEFAULT_THREAD_COUNT,
DEFAULT_BATCH_THREAD_COUNT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
DOMAIN,
PROMPT_TEMPLATE_DESCRIPTIONS,
DEFAULT_OPTIONS,
)
from homeassistant.components.conversation import ConversationInput
from homeassistant.const import (
CONF_HOST,
CONF_PORT,
CONF_SSL,
CONF_LLM_HASS_API
)
from homeassistant.helpers.llm import LLM_API_ASSIST, APIInstance
_LOGGER = logging.getLogger(__name__)
class WarnDict(dict):
def get(self, _key, _default=None):
if _key in self:
return self[_key]
_LOGGER.warning(f"attempting to get unset dictionary key {_key}")
return _default
class MockConfigEntry:
def __init__(self, entry_id='test_entry_id', data={}, options={}):
self.entry_id = entry_id
self.data = WarnDict(data)
# Use a mutable dict for options in tests
self.options = WarnDict(dict(options))
@pytest.fixture
def config_entry():
yield MockConfigEntry(
data={
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
CONF_BACKEND_TYPE: DEFAULT_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE: "/config/models/some-model.q4_k_m.gguf",
CONF_HOST: "localhost",
CONF_PORT: "5000",
CONF_SSL: False,
CONF_OPENAI_API_KEY: "OpenAI-API-Key",
CONF_TEXT_GEN_WEBUI_ADMIN_KEY: "Text-Gen-Webui-Admin-Key"
},
options={
**DEFAULT_OPTIONS,
CONF_LLM_HASS_API: LLM_API_ASSIST,
CONF_PROMPT: DEFAULT_PROMPT_BASE,
CONF_SERVICE_CALL_REGEX: r"({[\S \t]*})"
}
)
@pytest.fixture
def local_llama_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(LlamaCppAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(LlamaCppAgent, '_load_grammar') as load_grammar_mock, \
patch.object(LlamaCppAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.llamacpp.importlib.import_module') as import_module_mock, \
patch('custom_components.llama_conversation.utils.importlib.import_module') as import_module_mock_2, \
patch('custom_components.llama_conversation.utils.install_llama_cpp_python') as install_llama_cpp_python_mock:
entry_mock.return_value = config_entry
llama_instance_mock = MagicMock()
llama_class_mock = MagicMock()
llama_class_mock.return_value = llama_instance_mock
import_module_mock.return_value = MagicMock(Llama=llama_class_mock)
import_module_mock_2.return_value = MagicMock(Llama=llama_class_mock)
install_llama_cpp_python_mock.return_value = True
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
# template_mock.side_effect = lambda template, _: jinja2.Template(template)
generate_mock = llama_instance_mock.generate
generate_mock.return_value = list(range(20))
detokenize_mock = llama_instance_mock.detokenize
detokenize_mock.return_value = ("I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
})).encode()
call_tool_mock.return_value = {"result": "success"}
agent_obj = LlamaCppAgent(
hass,
config_entry
)
all_mocks = {
"llama_class": llama_class_mock,
"tokenize": llama_instance_mock.tokenize,
"generate": generate_mock,
}
yield agent_obj, all_mocks
# TODO: test base llama agent (ICL loading other languages)
async def test_local_llama_agent(local_llama_agent_fixture):
local_llama_agent: LlamaCppAgent
all_mocks: dict[str, MagicMock]
local_llama_agent, all_mocks = local_llama_agent_fixture
# invoke the conversation agent
conversation_id = "test-conversation"
result = await local_llama_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["llama_class"].assert_called_once_with(
model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
)
all_mocks["tokenize"].assert_called_once()
all_mocks["generate"].assert_called_once_with(
ANY,
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
min_p=local_llama_agent.entry.options[CONF_MIN_P],
grammar=ANY,
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# change options then apply them
local_llama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
local_llama_agent.entry.options[CONF_BATCH_SIZE] = 1024
local_llama_agent.entry.options[CONF_THREAD_COUNT] = 24
local_llama_agent.entry.options[CONF_BATCH_THREAD_COUNT] = 24
local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
local_llama_agent.entry.options[CONF_ENABLE_FLASH_ATTENTION] = True
local_llama_agent.entry.options[CONF_TOP_K] = 20
local_llama_agent.entry.options[CONF_TOP_P] = 0.9
local_llama_agent.entry.options[CONF_MIN_P] = 0.2
local_llama_agent.entry.options[CONF_TYPICAL_P] = 0.95
local_llama_agent._update_options()
all_mocks["llama_class"].assert_called_once_with(
model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
)
# do another turn of the same conversation
result = await local_llama_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["tokenize"].assert_called_once()
all_mocks["generate"].assert_called_once_with(
ANY,
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
min_p=local_llama_agent.entry.options[CONF_MIN_P],
grammar=ANY,
)
@pytest.fixture
def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(OllamaAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(OllamaAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(OllamaAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.ollama.async_get_clientsession') as get_clientsession:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
response_mock = MagicMock()
response_mock.json.return_value = { "models": [ {"name": config_entry.data[CONF_CHAT_MODEL] }] }
get_clientsession.get.return_value = response_mock
call_tool_mock.return_value = {"result": "success"}
agent_obj = OllamaAPIAgent(
hass,
config_entry
)
all_mocks = {
"requests_get": get_clientsession.get,
"requests_post": get_clientsession.post
}
yield agent_obj, all_mocks
async def test_ollama_agent(ollama_agent_fixture):
ollama_agent: OllamaAPIAgent
all_mocks: dict[str, MagicMock]
ollama_agent, all_mocks = ollama_agent_fixture
all_mocks["requests_get"].assert_called_once_with(
"http://localhost:5000/api/tags",
headers={ "Authorization": "Bearer OpenAI-API-Key" }
)
response_mock = MagicMock()
response_mock.json.return_value = {
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
"created_at": "2023-11-09T21:07:55.186497Z",
"response": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"done": True,
"context": [1, 2, 3],
"total_duration": 4648158584,
"load_duration": 4071084,
"prompt_eval_count": 36,
"prompt_eval_duration": 439038000,
"eval_count": 180,
"eval_duration": 4196918000
}
all_mocks["requests_post"].return_value = response_mock
# invoke the conversation agent
conversation_id = "test-conversation"
result = await ollama_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/api/generate",
headers={ "Authorization": "Bearer OpenAI-API-Key" },
json={
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
"stream": False,
"keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
"options": {
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
"top_p": ollama_agent.entry.options[CONF_TOP_P],
"top_k": ollama_agent.entry.options[CONF_TOP_K],
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
},
"prompt": ANY,
"raw": True
},
timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# change options
ollama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
ollama_agent.entry.options[CONF_MAX_TOKENS] = 10
ollama_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN] = 99
ollama_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
ollama_agent.entry.options[CONF_OLLAMA_JSON_MODE] = True
ollama_agent.entry.options[CONF_TEMPERATURE] = 2.0
ollama_agent.entry.options[CONF_TOP_K] = 20
ollama_agent.entry.options[CONF_TOP_P] = 0.9
ollama_agent.entry.options[CONF_TYPICAL_P] = 0.5
# do another turn of the same conversation
result = await ollama_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/api/chat",
headers={ "Authorization": "Bearer OpenAI-API-Key" },
json={
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
"stream": False,
"format": "json",
"keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
"options": {
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
"top_p": ollama_agent.entry.options[CONF_TOP_P],
"top_k": ollama_agent.entry.options[CONF_TOP_K],
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
},
"messages": ANY
},
timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
@pytest.fixture
def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(TextGenerationWebuiAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(TextGenerationWebuiAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(TextGenerationWebuiAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.tailored_openai.async_get_clientsession') as get_clientsession_mock:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
response_mock = MagicMock()
response_mock.json.return_value = { "model_name": config_entry.data[CONF_CHAT_MODEL] }
get_clientsession_mock.get.return_value = response_mock
call_tool_mock.return_value = {"result": "success"}
agent_obj = TextGenerationWebuiAgent(
hass,
config_entry
)
all_mocks = {
"requests_get": get_clientsession_mock.get,
"requests_post": get_clientsession_mock.post
}
yield agent_obj, all_mocks
async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
text_generation_webui_agent: TextGenerationWebuiAgent
all_mocks: dict[str, MagicMock]
text_generation_webui_agent, all_mocks = text_generation_webui_agent_fixture
all_mocks["requests_get"].assert_called_once_with(
"http://localhost:5000/v1/internal/model/info",
headers={ "Authorization": "Bearer Text-Gen-Webui-Admin-Key" }
)
response_mock = MagicMock()
response_mock.json.return_value = {
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "gpt-3.5-turbo-instruct",
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"text": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"index": 0,
"logprobs": None,
"finish_reason": "length"
}],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
}
}
all_mocks["requests_post"].return_value = response_mock
# invoke the conversation agent
conversation_id = "test-conversation"
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"prompt": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = "Some Preset"
# do another turn of the same conversation and use a preset
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"preset": "Some Preset",
"prompt": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# change options
text_generation_webui_agent.entry.options[CONF_MAX_TOKENS] = 10
text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
text_generation_webui_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
text_generation_webui_agent.entry.options[CONF_TEMPERATURE] = 2.0
text_generation_webui_agent.entry.options[CONF_TOP_P] = 0.9
text_generation_webui_agent.entry.options[CONF_MIN_P] = 0.2
text_generation_webui_agent.entry.options[CONF_TYPICAL_P] = 0.95
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = ""
response_mock.json.return_value = {
"id": "chatcmpl-123",
# text-gen-webui has a typo: the object field is 'chat.completions' instead of 'chat.completion'
"object": "chat.completions",
"created": 1677652288,
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# do another turn of the same conversation but the chat endpoint
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/chat/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
"messages": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
@pytest.fixture
def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(GenericOpenAIAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(GenericOpenAIAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(GenericOpenAIAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.backends.generic_openai.async_get_clientsession') as get_clientsession_mock:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
{
"light.kitchen_light": { "state": "on" },
"light.office_lamp": { "state": "on" },
"switch.downstairs_hallway": { "state": "off" },
"fan.bedroom": { "state": "on" },
},
["light", "switch", "fan"]
)
call_tool_mock.return_value = {"result": "success"}
agent_obj = GenericOpenAIAPIAgent(
hass,
config_entry
)
all_mocks = {
"requests_get": get_clientsession_mock.get,
"requests_post": get_clientsession_mock.post
}
yield agent_obj, all_mocks
async def test_generic_openai_agent(generic_openai_agent_fixture):
generic_openai_agent: GenericOpenAIAPIAgent
all_mocks: dict[str, MagicMock]
generic_openai_agent, all_mocks = generic_openai_agent_fixture
response_mock = MagicMock()
response_mock.json.return_value = {
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "gpt-3.5-turbo-instruct",
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"text": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"index": 0,
"logprobs": None,
"finish_reason": "length"
}],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
}
}
all_mocks["requests_post"].return_value = response_mock
# invoke the conversation agent
conversation_id = "test-conversation"
result = await generic_openai_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/completions",
json={
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
"top_p": generic_openai_agent.entry.options[CONF_TOP_P],
"temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
"max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
"prompt": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
# change options
generic_openai_agent.entry.options[CONF_MAX_TOKENS] = 10
generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
generic_openai_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
generic_openai_agent.entry.options[CONF_TEMPERATURE] = 2.0
generic_openai_agent.entry.options[CONF_TOP_P] = 0.9
response_mock.json.return_value = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}
# do another turn of the same conversation but the chat endpoint
result = await generic_openai_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/chat/completions",
json={
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
"top_p": generic_openai_agent.entry.options[CONF_TOP_P],
"temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
"max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
"messages": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
)

View File

@@ -0,0 +1,105 @@
"""Lightweight smoke tests for backend helpers.
These avoid backend calls and only cover helper utilities to keep the suite green
while the integration evolves. No integration code is modified.
"""
import pytest
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from custom_components.llama_conversation.backends.llamacpp import snapshot_settings
from custom_components.llama_conversation.backends.ollama import OllamaAPIClient, _normalize_path
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIClient
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_CONTEXT_LENGTH,
CONF_LLAMACPP_BATCH_SIZE,
CONF_LLAMACPP_BATCH_THREAD_COUNT,
CONF_LLAMACPP_THREAD_COUNT,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
CONF_GBNF_GRAMMAR_FILE,
CONF_PROMPT_CACHING_ENABLED,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_LLAMACPP_BATCH_SIZE,
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
DEFAULT_LLAMACPP_THREAD_COUNT,
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_PROMPT_CACHING_ENABLED,
CONF_GENERIC_OPENAI_PATH,
)
@pytest.fixture
async def hass_defaults(hass):
return hass
def test_snapshot_settings_defaults():
options = {CONF_CHAT_MODEL: "test-model"}
snap = snapshot_settings(options)
assert snap[CONF_CONTEXT_LENGTH] == DEFAULT_CONTEXT_LENGTH
assert snap[CONF_LLAMACPP_BATCH_SIZE] == DEFAULT_LLAMACPP_BATCH_SIZE
assert snap[CONF_LLAMACPP_THREAD_COUNT] == DEFAULT_LLAMACPP_THREAD_COUNT
assert snap[CONF_LLAMACPP_BATCH_THREAD_COUNT] == DEFAULT_LLAMACPP_BATCH_THREAD_COUNT
assert snap[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] == DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION
assert snap[CONF_GBNF_GRAMMAR_FILE] == DEFAULT_GBNF_GRAMMAR_FILE
assert snap[CONF_PROMPT_CACHING_ENABLED] == DEFAULT_PROMPT_CACHING_ENABLED
def test_snapshot_settings_overrides():
options = {
CONF_CONTEXT_LENGTH: 4096,
CONF_LLAMACPP_BATCH_SIZE: 64,
CONF_LLAMACPP_THREAD_COUNT: 6,
CONF_LLAMACPP_BATCH_THREAD_COUNT: 3,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION: True,
CONF_GBNF_GRAMMAR_FILE: "custom.gbnf",
CONF_PROMPT_CACHING_ENABLED: True,
}
snap = snapshot_settings(options)
assert snap[CONF_CONTEXT_LENGTH] == 4096
assert snap[CONF_LLAMACPP_BATCH_SIZE] == 64
assert snap[CONF_LLAMACPP_THREAD_COUNT] == 6
assert snap[CONF_LLAMACPP_BATCH_THREAD_COUNT] == 3
assert snap[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] is True
assert snap[CONF_GBNF_GRAMMAR_FILE] == "custom.gbnf"
assert snap[CONF_PROMPT_CACHING_ENABLED] is True
def test_ollama_keep_alive_formatting():
assert OllamaAPIClient._format_keep_alive("0") == 0
assert OllamaAPIClient._format_keep_alive("0.0") == 0
assert OllamaAPIClient._format_keep_alive(5) == "5m"
assert OllamaAPIClient._format_keep_alive("15") == "15m"
def test_generic_openai_name_and_path(hass_defaults):
client = GenericOpenAIAPIClient(
hass_defaults,
{
CONF_HOST: "localhost",
CONF_PORT: "8080",
CONF_SSL: False,
CONF_GENERIC_OPENAI_PATH: "v1",
CONF_CHAT_MODEL: "demo",
},
)
name = client.get_name(
{
CONF_HOST: "localhost",
CONF_PORT: "8080",
CONF_SSL: False,
CONF_GENERIC_OPENAI_PATH: "v1",
}
)
assert "Generic OpenAI" in name
assert "localhost" in name
def test_normalize_path_helper():
assert _normalize_path(None) == ""
assert _normalize_path("") == ""
assert _normalize_path("/v1/") == "/v1"
assert _normalize_path("v2") == "/v2"

View File

@@ -1,350 +1,204 @@
"""Config flow option schema tests to ensure options are wired per-backend."""
import pytest
from unittest.mock import patch, MagicMock
from homeassistant import config_entries, setup
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.const import (
CONF_HOST,
CONF_PORT,
CONF_SSL,
CONF_LLM_HASS_API,
)
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import llm
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema, ConfigFlow
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_PROMPT_TEMPLATE,
CONF_TOOL_FORMAT,
CONF_TOOL_MULTI_TURN_CHAT,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
CONF_BATCH_SIZE,
CONF_THREAD_COUNT,
CONF_BATCH_THREAD_COUNT,
BACKEND_TYPE_LLAMA_HF,
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_LLAMA_CPP,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_LLAMA_CPP_SERVER,
BACKEND_TYPE_OLLAMA,
DEFAULT_CHAT_MODEL,
DEFAULT_PROMPT,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
CONF_CONTEXT_LENGTH,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_GBNF_GRAMMAR_FILE,
CONF_LLAMACPP_BATCH_SIZE,
CONF_LLAMACPP_BATCH_THREAD_COUNT,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
CONF_LLAMACPP_THREAD_COUNT,
CONF_MAX_TOKENS,
CONF_MIN_P,
CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_OLLAMA_JSON_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_PROMPT,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_REQUEST_TIMEOUT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_THINKING_PREFIX,
CONF_TOOL_CALL_PREFIX,
CONF_TOP_K,
CONF_TOP_P,
CONF_TYPICAL_P,
CONF_TEMPERATURE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_LLAMACPP_BATCH_SIZE,
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
DEFAULT_LLAMACPP_THREAD_COUNT,
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_PROMPT_CACHING_ENABLED,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_BATCH_SIZE,
DEFAULT_THREAD_COUNT,
DEFAULT_BATCH_THREAD_COUNT,
DOMAIN,
DEFAULT_PROMPT,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_THINKING_PREFIX,
DEFAULT_TOOL_CALL_PREFIX,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_TYPICAL_P,
)
# async def test_validate_config_flow_llama_hf(hass: HomeAssistant):
# result = await hass.config_entries.flow.async_init(
# DOMAIN, context={"source": config_entries.SOURCE_USER}
# )
# assert result["type"] == FlowResultType.FORM
# assert result["errors"] is None
# result2 = await hass.config_entries.flow.async_configure(
# result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_HF },
# )
# assert result2["type"] == FlowResultType.FORM
# with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry:
# result3 = await hass.config_entries.flow.async_configure(
# result2["flow_id"],
# TEST_DATA,
# )
# await hass.async_block_till_done()
# assert result3["type"] == "create_entry"
# assert result3["title"] == ""
# assert result3["data"] == {
# # ACCOUNT_ID: TEST_DATA["account_id"],
# # CONF_PASSWORD: TEST_DATA["password"],
# # CONNECTION_TYPE: CLOUD,
# }
# assert result3["options"] == {}
# assert len(mock_setup_entry.mock_calls) == 1
@pytest.fixture
def validate_connections_mock():
validate_mock = MagicMock()
with patch.object(ConfigFlow, '_validate_text_generation_webui', new=validate_mock), \
patch.object(ConfigFlow, '_validate_ollama', new=validate_mock):
yield validate_mock
@pytest.fixture
def mock_setup_entry():
with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry, \
patch("custom_components.llama_conversation.async_unload_entry", return_value=True):
yield mock_setup_entry
async def test_validate_config_flow_generic_openai(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.FORM
assert result["errors"] == {}
assert result["step_id"] == "pick_backend"
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI },
def _schema(hass: HomeAssistant, backend: str, options: dict | None = None):
return local_llama_config_option_schema(
hass=hass,
language="en",
options=options or {},
backend_type=backend,
subentry_type="conversation",
)
assert result2["type"] == FlowResultType.FORM
assert result2["errors"] == {}
assert result2["step_id"] == "remote_model"
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
{
CONF_HOST: "localhost",
CONF_PORT: "5000",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
},
)
def _get_default(schema: dict, key_name: str):
for key in schema:
if getattr(key, "schema", None) == key_name:
default = getattr(key, "default", None)
return default() if callable(default) else default
raise AssertionError(f"Key {key_name} not found in schema")
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {}
assert result3["step_id"] == "model_parameters"
options_dict = {
CONF_PROMPT: DEFAULT_PROMPT,
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
CONF_TOP_P: DEFAULT_TOP_P,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
def _get_suggested(schema: dict, key_name: str):
for key in schema:
if getattr(key, "schema", None) == key_name:
return (getattr(key, "description", {}) or {}).get("suggested_value")
raise AssertionError(f"Key {key_name} not found in schema")
def test_schema_llama_cpp_defaults_and_overrides(hass: HomeAssistant):
overrides = {
CONF_CONTEXT_LENGTH: 4096,
CONF_LLAMACPP_BATCH_SIZE: 8,
CONF_LLAMACPP_THREAD_COUNT: 6,
CONF_LLAMACPP_BATCH_THREAD_COUNT: 3,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION: True,
CONF_PROMPT_CACHING_INTERVAL: 15,
CONF_TOP_K: 12,
CONF_TOOL_CALL_PREFIX: "<tc>",
}
result4 = await hass.config_entries.flow.async_configure(
result2["flow_id"], options_dict
)
await hass.async_block_till_done()
schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP, overrides)
assert result4["type"] == "create_entry"
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
assert result4["data"] == {
CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI,
CONF_HOST: "localhost",
CONF_PORT: "5000",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
expected_keys = {
CONF_MAX_TOKENS,
CONF_CONTEXT_LENGTH,
CONF_TOP_K,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_GBNF_GRAMMAR_FILE,
CONF_LLAMACPP_BATCH_SIZE,
CONF_LLAMACPP_THREAD_COUNT,
CONF_LLAMACPP_BATCH_THREAD_COUNT,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
}
assert result4["options"] == options_dict
assert len(mock_setup_entry.mock_calls) == 1
assert expected_keys.issubset({getattr(k, "schema", None) for k in schema})
async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations, validate_connections_mock):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.FORM
assert result["errors"] == {}
assert result["step_id"] == "pick_backend"
assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
assert _get_default(schema, CONF_LLAMACPP_BATCH_SIZE) == DEFAULT_LLAMACPP_BATCH_SIZE
assert _get_default(schema, CONF_LLAMACPP_THREAD_COUNT) == DEFAULT_LLAMACPP_THREAD_COUNT
assert _get_default(schema, CONF_LLAMACPP_BATCH_THREAD_COUNT) == DEFAULT_LLAMACPP_BATCH_THREAD_COUNT
assert _get_default(schema, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION) is DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION
assert _get_default(schema, CONF_PROMPT_CACHING_INTERVAL) == DEFAULT_PROMPT_CACHING_INTERVAL
# suggested values should reflect overrides
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 4096
assert _get_suggested(schema, CONF_LLAMACPP_BATCH_SIZE) == 8
assert _get_suggested(schema, CONF_LLAMACPP_THREAD_COUNT) == 6
assert _get_suggested(schema, CONF_LLAMACPP_BATCH_THREAD_COUNT) == 3
assert _get_suggested(schema, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION) is True
assert _get_suggested(schema, CONF_PROMPT_CACHING_INTERVAL) == 15
assert _get_suggested(schema, CONF_TOP_K) == 12
assert _get_suggested(schema, CONF_TOOL_CALL_PREFIX) == "<tc>"
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA },
)
assert result2["type"] == FlowResultType.FORM
assert result2["errors"] == {}
assert result2["step_id"] == "remote_model"
# simulate incorrect settings on first try
validate_connections_mock.side_effect = [
("failed_to_connect", Exception("ConnectionError"), []),
(None, None, [])
]
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
{
CONF_HOST: "localhost",
CONF_PORT: "5000",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
},
)
assert result3["type"] == FlowResultType.FORM
assert len(result3["errors"]) == 1
assert "base" in result3["errors"]
assert result3["step_id"] == "remote_model"
# retry
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
{
CONF_HOST: "localhost",
CONF_PORT: "5001",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
},
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {}
assert result3["step_id"] == "model_parameters"
options_dict = {
CONF_PROMPT: DEFAULT_PROMPT,
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
CONF_TOP_P: DEFAULT_TOP_P,
CONF_TOP_K: DEFAULT_TOP_K,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
CONF_TYPICAL_P: DEFAULT_MIN_P,
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
CONF_OLLAMA_KEEP_ALIVE_MIN: DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE: DEFAULT_OLLAMA_JSON_MODE,
def test_schema_text_gen_webui_options_preserved(hass: HomeAssistant):
overrides = {
CONF_REQUEST_TIMEOUT: 123,
CONF_TEXT_GEN_WEBUI_PRESET: "custom-preset",
CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_CONTEXT_LENGTH: 2048,
}
result4 = await hass.config_entries.flow.async_configure(
result2["flow_id"], options_dict
schema = _schema(hass, BACKEND_TYPE_TEXT_GEN_WEBUI, overrides)
expected = {CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, CONF_REQUEST_TIMEOUT, CONF_CONTEXT_LENGTH}
assert expected.issubset({getattr(k, "schema", None) for k in schema})
assert _get_default(schema, CONF_REQUEST_TIMEOUT) == DEFAULT_REQUEST_TIMEOUT
assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
assert _get_suggested(schema, CONF_REQUEST_TIMEOUT) == 123
assert _get_suggested(schema, CONF_TEXT_GEN_WEBUI_PRESET) == "custom-preset"
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 2048
def test_schema_generic_openai_options_preserved(hass: HomeAssistant):
overrides = {CONF_TOP_P: 0.25, CONF_REQUEST_TIMEOUT: 321}
schema = _schema(hass, BACKEND_TYPE_GENERIC_OPENAI, overrides)
assert {CONF_TOP_P, CONF_REQUEST_TIMEOUT}.issubset({getattr(k, "schema", None) for k in schema})
assert _get_default(schema, CONF_TOP_P) == DEFAULT_TOP_P
assert _get_default(schema, CONF_REQUEST_TIMEOUT) == DEFAULT_REQUEST_TIMEOUT
assert _get_suggested(schema, CONF_TOP_P) == 0.25
assert _get_suggested(schema, CONF_REQUEST_TIMEOUT) == 321
# Base prompt options still present
prompt_default = _get_default(schema, CONF_PROMPT)
assert prompt_default is not None and "You are 'Al'" in prompt_default
assert _get_default(schema, CONF_NUM_IN_CONTEXT_EXAMPLES) == DEFAULT_NUM_IN_CONTEXT_EXAMPLES
def test_schema_llama_cpp_server_includes_gbnf(hass: HomeAssistant):
schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP_SERVER)
keys = {getattr(k, "schema", None) for k in schema}
assert {CONF_MAX_TOKENS, CONF_TOP_K, CONF_GBNF_GRAMMAR_FILE}.issubset(keys)
assert _get_default(schema, CONF_GBNF_GRAMMAR_FILE) == "output.gbnf"
def test_schema_ollama_defaults_and_overrides(hass: HomeAssistant):
overrides = {CONF_OLLAMA_KEEP_ALIVE_MIN: 5, CONF_CONTEXT_LENGTH: 1024, CONF_TOP_K: 7}
schema = _schema(hass, BACKEND_TYPE_OLLAMA, overrides)
assert {CONF_MAX_TOKENS, CONF_CONTEXT_LENGTH, CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE}.issubset(
{getattr(k, "schema", None) for k in schema}
)
await hass.async_block_till_done()
assert _get_default(schema, CONF_OLLAMA_KEEP_ALIVE_MIN) == DEFAULT_OLLAMA_KEEP_ALIVE_MIN
assert _get_default(schema, CONF_OLLAMA_JSON_MODE) is DEFAULT_OLLAMA_JSON_MODE
assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
assert _get_default(schema, CONF_TOP_K) == DEFAULT_TOP_K
assert _get_suggested(schema, CONF_OLLAMA_KEEP_ALIVE_MIN) == 5
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 1024
assert _get_suggested(schema, CONF_TOP_K) == 7
assert result4["type"] == "create_entry"
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
assert result4["data"] == {
CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA,
CONF_HOST: "localhost",
CONF_PORT: "5001",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
}
assert result4["options"] == options_dict
mock_setup_entry.assert_called_once()
# TODO: write tests for configflow setup for llama.cpp (both versions) + text-generation-webui
def test_schema_includes_llm_api_selector(monkeypatch, hass: HomeAssistant):
monkeypatch.setattr(
"custom_components.llama_conversation.config_flow.llm.async_get_apis",
lambda _hass: [type("API", (), {"id": "dummy", "name": "Dummy API", "tools": []})()],
)
schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP)
def test_validate_options_schema(hass: HomeAssistant):
universal_options = [
CONF_LLM_HASS_API, CONF_PROMPT, CONF_PROMPT_TEMPLATE, CONF_TOOL_FORMAT, CONF_TOOL_MULTI_TURN_CHAT,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_SERVICE_CALL_REGEX, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS,
]
options_llama_hf = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_HF)
assert set(options_llama_hf.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
])
options_llama_existing = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_EXISTING)
assert set(options_llama_existing.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
])
options_ollama = local_llama_config_option_schema(hass, None, BACKEND_TYPE_OLLAMA)
assert set(options_ollama.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_TYPICAL_P, # supports top_k temperature, top_p and typical_p samplers
CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE, # ollama specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_text_gen_webui = local_llama_config_option_schema(hass, None, BACKEND_TYPE_TEXT_GEN_WEBUI)
assert set(options_text_gen_webui.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, # text-gen-webui specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_generic_openai = local_llama_config_option_schema(hass, None, BACKEND_TYPE_GENERIC_OPENAI)
assert set(options_generic_openai.keys()) == set(universal_options + [
CONF_TEMPERATURE, CONF_TOP_P, # only supports top_p and temperature sampling
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_llama_cpp_python_server = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER)
assert set(options_llama_cpp_python_server.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports top_k, temperature, and top p sampling
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
assert _get_default(schema, CONF_LLM_HASS_API) is None
# Base prompt and thinking prefixes use defaults when not overridden
prompt_default = _get_default(schema, CONF_PROMPT)
assert prompt_default is not None and "You are 'Al'" in prompt_default
assert _get_default(schema, CONF_THINKING_PREFIX) == DEFAULT_THINKING_PREFIX
assert _get_default(schema, CONF_TOOL_CALL_PREFIX) == DEFAULT_TOOL_CALL_PREFIX
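The _get_default and _get_suggested helpers above rely on how Home Assistant config flows build voluptuous schemas; a minimal sketch of the marker shape they introspect, with an illustrative key name:
import voluptuous as vol
# A config-flow option key: `default` is stored as a callable and the
# "suggested_value" in the description pre-fills the form field.
example_key = vol.Optional(
    "context_length",
    default=2048,
    description={"suggested_value": 4096},
)
assert example_key.schema == "context_length"
assert example_key.default() == 2048
assert example_key.description["suggested_value"] == 4096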

View File

@@ -0,0 +1,114 @@
"""Tests for LocalLLMAgent async_process."""
import pytest
from contextlib import contextmanager
from homeassistant.components.conversation import ConversationInput, SystemContent, AssistantContent
from homeassistant.const import MATCH_ALL
from custom_components.llama_conversation.conversation import LocalLLMAgent
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
DEFAULT_PROMPT,
DOMAIN,
)
class DummyClient:
def __init__(self, hass):
self.hass = hass
self.generated_prompts = []
def _generate_system_prompt(self, prompt_template, llm_api, entity_options):
self.generated_prompts.append(prompt_template)
return "rendered-system-prompt"
async def _async_generate(self, conv, agent_id, chat_log, entity_options):
async def gen():
yield AssistantContent(agent_id=agent_id, content="hello from llm")
return gen()
class DummySubentry:
def __init__(self, subentry_id="sub1", title="Test Agent", chat_model="model"):
self.subentry_id = subentry_id
self.title = title
self.subentry_type = DOMAIN
self.data = {CONF_CHAT_MODEL: chat_model}
class DummyEntry:
def __init__(self, entry_id="entry1", options=None, subentry=None, runtime_data=None):
self.entry_id = entry_id
self.options = options or {}
self.subentries = {subentry.subentry_id: subentry}
self.runtime_data = runtime_data
def add_update_listener(self, _cb):
return lambda: None
class FakeChatLog:
def __init__(self):
self.content = []
self.llm_api = None
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
class FakeChatSession:
def __enter__(self):
return {}
def __exit__(self, exc_type, exc, tb):
return False
@pytest.mark.asyncio
async def test_async_process_generates_response(monkeypatch, hass):
client = DummyClient(hass)
subentry = DummySubentry()
entry = DummyEntry(subentry=subentry, runtime_data=client)
# Make entry discoverable through hass data as LocalLLMEntity expects.
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry
@contextmanager
def fake_chat_session(_hass, _conversation_id):
yield FakeChatSession()
@contextmanager
def fake_chat_log(_hass, _session, _user_input):
yield FakeChatLog()
monkeypatch.setattr(
"custom_components.llama_conversation.conversation.chat_session.async_get_chat_session",
fake_chat_session,
)
monkeypatch.setattr(
"custom_components.llama_conversation.conversation.conversation.async_get_chat_log",
fake_chat_log,
)
agent = LocalLLMAgent(hass, entry, subentry, client)
result = await agent.async_process(
ConversationInput(
text="turn on the lights",
context=None,
conversation_id="conv-id",
device_id=None,
language="en",
agent_id="agent-1",
)
)
assert result.response.speech["plain"]["speech"] == "hello from llm"
# System prompt should be rendered once when message history is empty.
assert client.generated_prompts == [DEFAULT_PROMPT]
assert agent.supported_languages == MATCH_ALL

View File

@@ -0,0 +1,162 @@
"""Tests for LocalLLMClient helpers in entity.py."""
import inspect
import json
import pytest
from json import JSONDecodeError
from custom_components.llama_conversation.entity import LocalLLMClient
from custom_components.llama_conversation.const import (
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_TOOL_CALL_PREFIX,
DEFAULT_TOOL_CALL_SUFFIX,
DEFAULT_THINKING_PREFIX,
DEFAULT_THINKING_SUFFIX,
)
class DummyLocalClient(LocalLLMClient):
@staticmethod
def get_name(_client_options):
return "dummy"
class DummyLLMApi:
def __init__(self):
self.tools = []
@pytest.fixture
def client(hass):
# Disable ICL loading during tests to avoid filesystem access.
return DummyLocalClient(hass, {CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False})
@pytest.mark.asyncio
async def test_async_parse_completion_parses_tool_call(client):
raw_tool = '{"name":"light.turn_on","arguments":{"brightness":0.5,"to_say":" acknowledged"}}'
completion = (
f"{DEFAULT_THINKING_PREFIX}internal{DEFAULT_THINKING_SUFFIX}"
f"hello {DEFAULT_TOOL_CALL_PREFIX}{raw_tool}{DEFAULT_TOOL_CALL_SUFFIX}"
)
result = await client._async_parse_completion(DummyLLMApi(), "agent-id", {}, completion)
assert result.response.strip().startswith("hello")
assert "acknowledged" in result.response
assert result.tool_calls
tool_call = result.tool_calls[0]
assert tool_call.tool_name == "light.turn_on"
assert tool_call.tool_args["brightness"] == 127
@pytest.mark.asyncio
async def test_async_parse_completion_ignores_tools_without_llm_api(client):
raw_tool = '{"name":"light.turn_on","arguments":{"brightness":1}}'
completion = f"hello {DEFAULT_TOOL_CALL_PREFIX}{raw_tool}{DEFAULT_TOOL_CALL_SUFFIX}"
result = await client._async_parse_completion(None, "agent-id", {}, completion)
assert result.tool_calls == []
assert result.response.strip() == "hello"
@pytest.mark.asyncio
async def test_async_parse_completion_malformed_tool_raises(client):
bad_tool = f"{DEFAULT_TOOL_CALL_PREFIX}{{not-json{DEFAULT_TOOL_CALL_SUFFIX}"
with pytest.raises(JSONDecodeError):
await client._async_parse_completion(DummyLLMApi(), "agent-id", {}, bad_tool)
@pytest.mark.asyncio
async def test_async_stream_parse_completion_handles_streamed_tool_call(client):
async def token_generator():
yield ("Hi", None)
yield (
None,
[
{
"function": {
"name": "light.turn_on",
"arguments": {"brightness": 0.25, "to_say": " ok"},
}
}
],
)
stream = client._async_stream_parse_completion(
DummyLLMApi(), "agent-id", {}, anext_token=token_generator()
)
results = [chunk async for chunk in stream]
assert results[0].response == "Hi"
assert results[1].response.strip() == "ok"
assert results[1].tool_calls[0].tool_args["brightness"] == 63
@pytest.mark.asyncio
async def test_async_stream_parse_completion_malformed_tool_raises(client):
async def token_generator():
yield ("Hi", None)
yield (None, ["{not-json"])
with pytest.raises(JSONDecodeError):
async for _chunk in client._async_stream_parse_completion(
DummyLLMApi(), "agent-id", {}, anext_token=token_generator()
):
pass
@pytest.mark.asyncio
async def test_async_stream_parse_completion_ignores_tools_without_llm_api(client):
async def token_generator():
yield ("Hi", None)
yield (None, ["{}"])
results = [chunk async for chunk in client._async_stream_parse_completion(
None, "agent-id", {}, anext_token=token_generator()
)]
assert results[0].response == "Hi"
assert results[1].tool_calls is None
@pytest.mark.asyncio
async def test_async_get_exposed_entities_respects_exposure(monkeypatch, client, hass):
hass.states.async_set("light.exposed", "on", {"friendly_name": "Lamp"})
hass.states.async_set("switch.hidden", "off", {"friendly_name": "Hidden"})
monkeypatch.setattr(
"custom_components.llama_conversation.entity.async_should_expose",
lambda _hass, _domain, entity_id: not entity_id.endswith("hidden"),
)
exposed = client._async_get_exposed_entities()
assert "light.exposed" in exposed
assert "switch.hidden" not in exposed
assert exposed["light.exposed"]["friendly_name"] == "Lamp"
assert exposed["light.exposed"]["state"] == "on"
@pytest.mark.asyncio
async def test_generate_system_prompt_renders(monkeypatch, client, hass):
hass.states.async_set("light.kitchen", "on", {"friendly_name": "Kitchen"})
monkeypatch.setattr(
"custom_components.llama_conversation.entity.async_should_expose",
lambda _hass, _domain, _entity_id: True,
)
rendered = client._generate_system_prompt(
"Devices:\n{{ formatted_devices }}",
llm_api=None,
entity_options={CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: []},
)
if inspect.iscoroutine(rendered):
rendered = await rendered
assert isinstance(rendered, str)
assert "light.kitchen" in rendered

View File

@@ -0,0 +1,159 @@
"""Regression tests for config entry migration in __init__.py."""
import pytest
from homeassistant.const import CONF_LLM_HASS_API, CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.config_entries import ConfigSubentry
from pytest_homeassistant_custom_component.common import MockConfigEntry
from custom_components.llama_conversation import async_migrate_entry
from custom_components.llama_conversation.const import (
BACKEND_TYPE_LLAMA_CPP,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_SERVER,
CONF_BACKEND_TYPE,
CONF_CHAT_MODEL,
CONF_CONTEXT_LENGTH,
CONF_DOWNLOADED_MODEL_FILE,
CONF_DOWNLOADED_MODEL_QUANTIZATION,
CONF_GENERIC_OPENAI_PATH,
CONF_PROMPT,
CONF_REQUEST_TIMEOUT,
DOMAIN,
)
@pytest.mark.asyncio
async def test_migrate_v1_is_rejected(hass):
entry = MockConfigEntry(domain=DOMAIN, data={CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_CPP}, version=1)
entry.add_to_hass(hass)
result = await async_migrate_entry(hass, entry)
assert result is False
@pytest.mark.asyncio
async def test_migrate_v2_creates_subentry_and_updates_entry(monkeypatch, hass):
entry = MockConfigEntry(
domain=DOMAIN,
title="llama 'Test Agent' entry",
data={CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI},
options={
CONF_HOST: "localhost",
CONF_PORT: "8080",
CONF_SSL: False,
CONF_GENERIC_OPENAI_PATH: "v1",
CONF_PROMPT: "hello",
CONF_REQUEST_TIMEOUT: 90,
CONF_CHAT_MODEL: "model-x",
CONF_CONTEXT_LENGTH: 1024,
},
version=2,
)
entry.add_to_hass(hass)
added_subentries = []
update_calls = []
def fake_add_subentry(cfg_entry, subentry):
added_subentries.append((cfg_entry, subentry))
def fake_update_entry(cfg_entry, **kwargs):
update_calls.append(kwargs)
monkeypatch.setattr(hass.config_entries, "async_add_subentry", fake_add_subentry)
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
result = await async_migrate_entry(hass, entry)
assert result is True
assert added_subentries, "Subentry should be added"
subentry = added_subentries[0][1]
assert isinstance(subentry, ConfigSubentry)
assert subentry.subentry_type == "conversation"
assert subentry.data[CONF_CHAT_MODEL] == "model-x"
# Entry should be updated to version 3 with data/options separated
assert any(call.get("version") == 3 for call in update_calls)
last_options = [c["options"] for c in update_calls if "options" in c][-1]
assert last_options[CONF_HOST] == "localhost"
assert CONF_PROMPT not in last_options # moved to subentry
@pytest.mark.asyncio
async def test_migrate_v3_minor0_downloads_model(monkeypatch, hass):
sub_data = {
CONF_CHAT_MODEL: "model-a",
CONF_DOWNLOADED_MODEL_QUANTIZATION: "Q4_K_M",
CONF_REQUEST_TIMEOUT: 30,
}
subentry = ConfigSubentry(data=sub_data, subentry_type="conversation", title="sub", unique_id=None)
entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_CPP},
options={},
version=3,
minor_version=0,
)
entry.subentries = {"sub": subentry}
entry.add_to_hass(hass)
updated_subentries = []
update_calls = []
def fake_update_subentry(cfg_entry, old_sub, *, data=None, **_kwargs):
updated_subentries.append((cfg_entry, old_sub, data))
def fake_update_entry(cfg_entry, **kwargs):
update_calls.append(kwargs)
monkeypatch.setattr(
"custom_components.llama_conversation.download_model_from_hf", lambda *_args, **_kw: "file.gguf"
)
monkeypatch.setattr(hass.config_entries, "async_update_subentry", fake_update_subentry)
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
result = await async_migrate_entry(hass, entry)
assert result is True
assert updated_subentries, "Subentry should be updated with downloaded file"
new_data = updated_subentries[0][2]
assert new_data[CONF_DOWNLOADED_MODEL_FILE] == "file.gguf"
assert any(call.get("minor_version") == 1 for call in update_calls)
@pytest.mark.parametrize(
"api_value,expected_list",
[("api-1", ["api-1"]), (None, [])],
)
@pytest.mark.asyncio
async def test_migrate_v3_minor1_converts_api_to_list(monkeypatch, hass, api_value, expected_list):
entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI},
options={CONF_LLM_HASS_API: api_value},
version=3,
minor_version=1,
)
entry.add_to_hass(hass)
calls = []
def fake_update_entry(cfg_entry, **kwargs):
calls.append(kwargs)
if "options" in kwargs:
cfg_entry._options = kwargs["options"] # type: ignore[attr-defined]
if "minor_version" in kwargs:
cfg_entry._minor_version = kwargs["minor_version"] # type: ignore[attr-defined]
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
result = await async_migrate_entry(hass, entry)
assert result is True
options_calls = [c for c in calls if "options" in c]
assert options_calls, "async_update_entry should be called with options"
assert options_calls[-1]["options"][CONF_LLM_HASS_API] == expected_list
minor_calls = [c for c in calls if c.get("minor_version")]
assert minor_calls and minor_calls[-1]["minor_version"] == 2
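Taken together, these cases exercise a staged migration; a skeleton consistent with what the tests check, not the integration's actual async_migrate_entry:
async def async_migrate_entry_sketch(hass, entry):
    if entry.version == 1:
        return False  # v1 entries are rejected outright
    if entry.version == 2:
        ...  # split conversation options into a ConfigSubentry, bump to version 3
    if entry.version == 3 and entry.minor_version == 0:
        ...  # download the model file for the subentry, bump to minor_version 1
    if entry.version == 3 and entry.minor_version == 1:
        ...  # wrap CONF_LLM_HASS_API in a list, bump to minor_version 2
    return True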