Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-08 05:14:02 -05:00)
rewrite tests from scratch
@@ -43,8 +43,6 @@ from .const import (
DEFAULT_TOOL_CALL_PREFIX,
DEFAULT_TOOL_CALL_SUFFIX,
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME,
)

_LOGGER = logging.getLogger(__name__)

@@ -6,4 +6,5 @@ home-assistant-intents
# testing requirements
pytest
pytest-asyncio
pytest-homeassistant-custom-component==0.13.260
# NOTE this must match the version of Home Assistant used for testing
pytest-homeassistant-custom-component==0.13.272

@@ -1,763 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import pytest
|
||||
import jinja2
|
||||
from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY
|
||||
|
||||
from custom_components.llama_conversation.backends.llamacpp import LlamaCppAgent
|
||||
from custom_components.llama_conversation.backends.ollama import OllamaAPIAgent
|
||||
from custom_components.llama_conversation.backends.tailored_openai import TextGenerationWebuiAgent
|
||||
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_PROMPT,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_MIN_P,
|
||||
CONF_TYPICAL_P,
|
||||
CONF_REQUEST_TIMEOUT,
|
||||
CONF_BACKEND_TYPE,
|
||||
CONF_DOWNLOADED_MODEL_FILE,
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_PROMPT_TEMPLATE,
|
||||
CONF_ENABLE_FLASH_ATTENTION,
|
||||
CONF_USE_GBNF_GRAMMAR,
|
||||
CONF_GBNF_GRAMMAR_FILE,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
||||
CONF_TEXT_GEN_WEBUI_PRESET,
|
||||
CONF_OPENAI_API_KEY,
|
||||
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
|
||||
CONF_REFRESH_SYSTEM_PROMPT,
|
||||
CONF_REMEMBER_CONVERSATION,
|
||||
CONF_REMEMBER_NUM_INTERACTIONS,
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
CONF_PROMPT_CACHING_INTERVAL,
|
||||
CONF_SERVICE_CALL_REGEX,
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT,
|
||||
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
CONF_OLLAMA_KEEP_ALIVE_MIN,
|
||||
CONF_OLLAMA_JSON_MODE,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_BATCH_SIZE,
|
||||
CONF_THREAD_COUNT,
|
||||
CONF_BATCH_THREAD_COUNT,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_PROMPT_BASE,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_K,
|
||||
DEFAULT_TOP_P,
|
||||
DEFAULT_MIN_P,
|
||||
DEFAULT_TYPICAL_P,
|
||||
DEFAULT_BACKEND_TYPE,
|
||||
DEFAULT_REQUEST_TIMEOUT,
|
||||
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
DEFAULT_PROMPT_TEMPLATE,
|
||||
DEFAULT_ENABLE_FLASH_ATTENTION,
|
||||
DEFAULT_USE_GBNF_GRAMMAR,
|
||||
DEFAULT_GBNF_GRAMMAR_FILE,
|
||||
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
||||
DEFAULT_REFRESH_SYSTEM_PROMPT,
|
||||
DEFAULT_REMEMBER_CONVERSATION,
|
||||
DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
||||
DEFAULT_PROMPT_CACHING_ENABLED,
|
||||
DEFAULT_PROMPT_CACHING_INTERVAL,
|
||||
DEFAULT_SERVICE_CALL_REGEX,
|
||||
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
|
||||
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
|
||||
DEFAULT_OLLAMA_JSON_MODE,
|
||||
DEFAULT_CONTEXT_LENGTH,
|
||||
DEFAULT_BATCH_SIZE,
|
||||
DEFAULT_THREAD_COUNT,
|
||||
DEFAULT_BATCH_THREAD_COUNT,
|
||||
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
|
||||
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
|
||||
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
|
||||
DOMAIN,
|
||||
PROMPT_TEMPLATE_DESCRIPTIONS,
|
||||
DEFAULT_OPTIONS,
|
||||
)
|
||||
|
||||
from homeassistant.components.conversation import ConversationInput
|
||||
from homeassistant.const import (
|
||||
CONF_HOST,
|
||||
CONF_PORT,
|
||||
CONF_SSL,
|
||||
CONF_LLM_HASS_API
|
||||
)
|
||||
from homeassistant.helpers.llm import LLM_API_ASSIST, APIInstance
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class WarnDict(dict):
|
||||
def get(self, _key, _default=None):
|
||||
if _key in self:
|
||||
return self[_key]
|
||||
|
||||
_LOGGER.warning(f"attempting to get unset dictionary key {_key}")
|
||||
|
||||
return _default
|
||||
|
||||
class MockConfigEntry:
|
||||
def __init__(self, entry_id='test_entry_id', data={}, options={}):
|
||||
self.entry_id = entry_id
|
||||
self.data = WarnDict(data)
|
||||
# Use a mutable dict for options in tests
|
||||
self.options = WarnDict(dict(options))
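# Note: WarnDict and MockConfigEntry are lightweight test doubles. WarnDict logs a
# warning whenever a test reads a config key that was never set, which makes missing
# fixture data easy to spot; MockConfigEntry mimics only the entry_id/data/options
# surface of Home Assistant's ConfigEntry that these tests touch.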
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_entry():
|
||||
yield MockConfigEntry(
|
||||
data={
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
CONF_BACKEND_TYPE: DEFAULT_BACKEND_TYPE,
|
||||
CONF_DOWNLOADED_MODEL_FILE: "/config/models/some-model.q4_k_m.gguf",
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5000",
|
||||
CONF_SSL: False,
|
||||
CONF_OPENAI_API_KEY: "OpenAI-API-Key",
|
||||
CONF_TEXT_GEN_WEBUI_ADMIN_KEY: "Text-Gen-Webui-Admin-Key"
|
||||
},
|
||||
options={
|
||||
**DEFAULT_OPTIONS,
|
||||
CONF_LLM_HASS_API: LLM_API_ASSIST,
|
||||
CONF_PROMPT: DEFAULT_PROMPT_BASE,
|
||||
CONF_SERVICE_CALL_REGEX: r"({[\S \t]*})"
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def local_llama_agent_fixture(config_entry, hass, enable_custom_integrations):
|
||||
with patch.object(LlamaCppAgent, '_load_icl_examples') as load_icl_examples_mock, \
|
||||
patch.object(LlamaCppAgent, '_load_grammar') as load_grammar_mock, \
|
||||
patch.object(LlamaCppAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
|
||||
patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
|
||||
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
|
||||
patch('homeassistant.helpers.template.Template') as template_mock, \
|
||||
patch('custom_components.llama_conversation.backends.llamacpp.importlib.import_module') as import_module_mock, \
|
||||
patch('custom_components.llama_conversation.utils.importlib.import_module') as import_module_mock_2, \
|
||||
patch('custom_components.llama_conversation.utils.install_llama_cpp_python') as install_llama_cpp_python_mock:
|
||||
|
||||
entry_mock.return_value = config_entry
|
||||
llama_instance_mock = MagicMock()
|
||||
llama_class_mock = MagicMock()
|
||||
llama_class_mock.return_value = llama_instance_mock
|
||||
import_module_mock.return_value = MagicMock(Llama=llama_class_mock)
|
||||
import_module_mock_2.return_value = MagicMock(Llama=llama_class_mock)
|
||||
install_llama_cpp_python_mock.return_value = True
|
||||
get_exposed_entities_mock.return_value = (
|
||||
{
|
||||
"light.kitchen_light": { "state": "on" },
|
||||
"light.office_lamp": { "state": "on" },
|
||||
"switch.downstairs_hallway": { "state": "off" },
|
||||
"fan.bedroom": { "state": "on" },
|
||||
},
|
||||
["light", "switch", "fan"]
|
||||
)
|
||||
# template_mock.side_effect = lambda template, _: jinja2.Template(template)
|
||||
generate_mock = llama_instance_mock.generate
|
||||
generate_mock.return_value = list(range(20))
|
||||
|
||||
detokenize_mock = llama_instance_mock.detokenize
|
||||
detokenize_mock.return_value = ("I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
})).encode()
|
||||
|
||||
call_tool_mock.return_value = {"result": "success"}
|
||||
|
||||
agent_obj = LlamaCppAgent(
|
||||
hass,
|
||||
config_entry
|
||||
)
|
||||
|
||||
all_mocks = {
|
||||
"llama_class": llama_class_mock,
|
||||
"tokenize": llama_instance_mock.tokenize,
|
||||
"generate": generate_mock,
|
||||
}
|
||||
|
||||
yield agent_obj, all_mocks
|
||||
|
||||
# TODO: test the base llama agent (ICL loading for other languages)
|
||||
|
||||
async def test_local_llama_agent(local_llama_agent_fixture):
|
||||
|
||||
local_llama_agent: LlamaCppAgent
|
||||
all_mocks: dict[str, MagicMock]
|
||||
local_llama_agent, all_mocks = local_llama_agent_fixture
|
||||
|
||||
# invoke the conversation agent
|
||||
conversation_id = "test-conversation"
|
||||
result = await local_llama_agent.async_process(ConversationInput(
|
||||
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
# assert on results + check side effects
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["llama_class"].assert_called_once_with(
|
||||
model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
|
||||
n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
|
||||
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
|
||||
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
|
||||
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
|
||||
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
|
||||
)
|
||||
|
||||
all_mocks["tokenize"].assert_called_once()
|
||||
all_mocks["generate"].assert_called_once_with(
|
||||
ANY,
|
||||
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
|
||||
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
|
||||
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
|
||||
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
|
||||
min_p=local_llama_agent.entry.options[CONF_MIN_P],
|
||||
grammar=ANY,
|
||||
)
|
||||
|
||||
# reset mock stats
|
||||
for mock in all_mocks.values():
|
||||
mock.reset_mock()
|
||||
|
||||
# change options then apply them
|
||||
local_llama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
|
||||
local_llama_agent.entry.options[CONF_BATCH_SIZE] = 1024
|
||||
local_llama_agent.entry.options[CONF_THREAD_COUNT] = 24
|
||||
local_llama_agent.entry.options[CONF_BATCH_THREAD_COUNT] = 24
|
||||
local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
|
||||
local_llama_agent.entry.options[CONF_ENABLE_FLASH_ATTENTION] = True
|
||||
local_llama_agent.entry.options[CONF_TOP_K] = 20
|
||||
local_llama_agent.entry.options[CONF_TOP_P] = 0.9
|
||||
local_llama_agent.entry.options[CONF_MIN_P] = 0.2
|
||||
local_llama_agent.entry.options[CONF_TYPICAL_P] = 0.95
|
||||
|
||||
local_llama_agent._update_options()
|
||||
|
||||
all_mocks["llama_class"].assert_called_once_with(
|
||||
model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
|
||||
n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
|
||||
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
|
||||
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
|
||||
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
|
||||
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
|
||||
)
|
||||
|
||||
# do another turn of the same conversation
|
||||
result = await local_llama_agent.async_process(ConversationInput(
|
||||
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["tokenize"].assert_called_once()
|
||||
all_mocks["generate"].assert_called_once_with(
|
||||
ANY,
|
||||
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
|
||||
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
|
||||
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
|
||||
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
|
||||
min_p=local_llama_agent.entry.options[CONF_MIN_P],
|
||||
grammar=ANY,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
|
||||
with patch.object(OllamaAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
|
||||
patch.object(OllamaAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
|
||||
patch.object(OllamaAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
|
||||
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
|
||||
patch('homeassistant.helpers.template.Template') as template_mock, \
|
||||
patch('custom_components.llama_conversation.backends.ollama.async_get_clientsession') as get_clientsession:
|
||||
|
||||
entry_mock.return_value = config_entry
|
||||
get_exposed_entities_mock.return_value = (
|
||||
{
|
||||
"light.kitchen_light": { "state": "on" },
|
||||
"light.office_lamp": { "state": "on" },
|
||||
"switch.downstairs_hallway": { "state": "off" },
|
||||
"fan.bedroom": { "state": "on" },
|
||||
},
|
||||
["light", "switch", "fan"]
|
||||
)
|
||||
|
||||
response_mock = MagicMock()
|
||||
response_mock.json.return_value = { "models": [ {"name": config_entry.data[CONF_CHAT_MODEL] }] }
|
||||
get_clientsession.get.return_value = response_mock
|
||||
|
||||
call_tool_mock.return_value = {"result": "success"}
|
||||
|
||||
agent_obj = OllamaAPIAgent(
|
||||
hass,
|
||||
config_entry
|
||||
)
|
||||
|
||||
all_mocks = {
|
||||
"requests_get": get_clientsession.get,
|
||||
"requests_post": get_clientsession.post
|
||||
}
|
||||
|
||||
yield agent_obj, all_mocks
|
||||
|
||||
async def test_ollama_agent(ollama_agent_fixture):
|
||||
|
||||
ollama_agent: OllamaAPIAgent
|
||||
all_mocks: dict[str, MagicMock]
|
||||
ollama_agent, all_mocks = ollama_agent_fixture
|
||||
|
||||
all_mocks["requests_get"].assert_called_once_with(
|
||||
"http://localhost:5000/api/tags",
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" }
|
||||
)
|
||||
|
||||
response_mock = MagicMock()
|
||||
response_mock.json.return_value = {
|
||||
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"created_at": "2023-11-09T21:07:55.186497Z",
|
||||
"response": "I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
}),
|
||||
"done": True,
|
||||
"context": [1, 2, 3],
|
||||
"total_duration": 4648158584,
|
||||
"load_duration": 4071084,
|
||||
"prompt_eval_count": 36,
|
||||
"prompt_eval_duration": 439038000,
|
||||
"eval_count": 180,
|
||||
"eval_duration": 4196918000
|
||||
}
|
||||
all_mocks["requests_post"].return_value = response_mock
|
||||
|
||||
# invoke the conversation agent
|
||||
conversation_id = "test-conversation"
|
||||
result = await ollama_agent.async_process(ConversationInput(
|
||||
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
# assert on results + check side effects
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/api/generate",
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
json={
|
||||
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"stream": False,
|
||||
"keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
|
||||
"options": {
|
||||
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
|
||||
"top_p": ollama_agent.entry.options[CONF_TOP_P],
|
||||
"top_k": ollama_agent.entry.options[CONF_TOP_K],
|
||||
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
|
||||
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
|
||||
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
|
||||
},
|
||||
"prompt": ANY,
|
||||
"raw": True
|
||||
},
|
||||
timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
# reset mock stats
|
||||
for mock in all_mocks.values():
|
||||
mock.reset_mock()
|
||||
|
||||
# change options
|
||||
ollama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
|
||||
ollama_agent.entry.options[CONF_MAX_TOKENS] = 10
|
||||
ollama_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
|
||||
ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN] = 99
|
||||
ollama_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
|
||||
ollama_agent.entry.options[CONF_OLLAMA_JSON_MODE] = True
|
||||
ollama_agent.entry.options[CONF_TEMPERATURE] = 2.0
|
||||
ollama_agent.entry.options[CONF_TOP_K] = 20
|
||||
ollama_agent.entry.options[CONF_TOP_P] = 0.9
|
||||
ollama_agent.entry.options[CONF_TYPICAL_P] = 0.5
|
||||
|
||||
# do another turn of the same conversation
|
||||
result = await ollama_agent.async_process(ConversationInput(
|
||||
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/api/chat",
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
json={
|
||||
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"stream": False,
|
||||
"format": "json",
|
||||
"keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
|
||||
"options": {
|
||||
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
|
||||
"top_p": ollama_agent.entry.options[CONF_TOP_P],
|
||||
"top_k": ollama_agent.entry.options[CONF_TOP_K],
|
||||
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
|
||||
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
|
||||
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
|
||||
},
|
||||
"messages": ANY
|
||||
},
|
||||
timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integrations):
|
||||
with patch.object(TextGenerationWebuiAgent, '_load_icl_examples') as load_icl_examples_mock, \
|
||||
patch.object(TextGenerationWebuiAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
|
||||
patch.object(TextGenerationWebuiAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
|
||||
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
|
||||
patch('homeassistant.helpers.template.Template') as template_mock, \
|
||||
patch('custom_components.llama_conversation.backends.tailored_openai.async_get_clientsession') as get_clientsession_mock:
|
||||
|
||||
entry_mock.return_value = config_entry
|
||||
get_exposed_entities_mock.return_value = (
|
||||
{
|
||||
"light.kitchen_light": { "state": "on" },
|
||||
"light.office_lamp": { "state": "on" },
|
||||
"switch.downstairs_hallway": { "state": "off" },
|
||||
"fan.bedroom": { "state": "on" },
|
||||
},
|
||||
["light", "switch", "fan"]
|
||||
)
|
||||
|
||||
response_mock = MagicMock()
|
||||
response_mock.json.return_value = { "model_name": config_entry.data[CONF_CHAT_MODEL] }
|
||||
get_clientsession_mock.get.return_value = response_mock
|
||||
|
||||
call_tool_mock.return_value = {"result": "success"}
|
||||
|
||||
agent_obj = TextGenerationWebuiAgent(
|
||||
hass,
|
||||
config_entry
|
||||
)
|
||||
|
||||
all_mocks = {
|
||||
"requests_get": get_clientsession_mock.get,
|
||||
"requests_post": get_clientsession_mock.post
|
||||
}
|
||||
|
||||
yield agent_obj, all_mocks
|
||||
|
||||
async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
|
||||
|
||||
text_generation_webui_agent: TextGenerationWebuiAgent
|
||||
all_mocks: dict[str, MagicMock]
|
||||
text_generation_webui_agent, all_mocks = text_generation_webui_agent_fixture
|
||||
|
||||
all_mocks["requests_get"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/internal/model/info",
|
||||
headers={ "Authorization": "Bearer Text-Gen-Webui-Admin-Key" }
|
||||
)
|
||||
|
||||
response_mock = MagicMock()
|
||||
response_mock.json.return_value = {
|
||||
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
|
||||
"object": "text_completion",
|
||||
"created": 1589478378,
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [{
|
||||
"text": "I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
}),
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "length"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 12
|
||||
}
|
||||
}
|
||||
all_mocks["requests_post"].return_value = response_mock
|
||||
|
||||
# invoke the conversation agent
|
||||
conversation_id = "test-conversation"
|
||||
result = await text_generation_webui_agent.async_process(ConversationInput(
|
||||
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
# assert on results + check side effects
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/completions",
|
||||
json={
|
||||
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
|
||||
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
|
||||
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
|
||||
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
|
||||
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
|
||||
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
|
||||
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
|
||||
"prompt": ANY
|
||||
},
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
# reset mock stats
|
||||
for mock in all_mocks.values():
|
||||
mock.reset_mock()
|
||||
|
||||
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = "Some Preset"
|
||||
|
||||
# do another turn of the same conversation and use a preset
|
||||
result = await text_generation_webui_agent.async_process(ConversationInput(
|
||||
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/completions",
|
||||
json={
|
||||
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
|
||||
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
|
||||
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
|
||||
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
|
||||
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
|
||||
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
|
||||
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
|
||||
"preset": "Some Preset",
|
||||
"prompt": ANY
|
||||
},
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
# change options
|
||||
text_generation_webui_agent.entry.options[CONF_MAX_TOKENS] = 10
|
||||
text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
|
||||
text_generation_webui_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
|
||||
text_generation_webui_agent.entry.options[CONF_TEMPERATURE] = 2.0
|
||||
text_generation_webui_agent.entry.options[CONF_TOP_P] = 0.9
|
||||
text_generation_webui_agent.entry.options[CONF_MIN_P] = 0.2
|
||||
text_generation_webui_agent.entry.options[CONF_TYPICAL_P] = 0.95
|
||||
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = ""
|
||||
|
||||
response_mock.json.return_value = {
|
||||
"id": "chatcmpl-123",
|
||||
# text-gen-webui has a typo: it returns 'chat.completions' instead of 'chat.completion'
|
||||
"object": "chat.completions",
|
||||
"created": 1677652288,
|
||||
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
}),
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 9,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 21
|
||||
}
|
||||
}
|
||||
|
||||
# reset mock stats
|
||||
for mock in all_mocks.values():
|
||||
mock.reset_mock()
|
||||
|
||||
# do another turn of the same conversation, but using the chat endpoint
|
||||
result = await text_generation_webui_agent.async_process(ConversationInput(
|
||||
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/chat/completions",
|
||||
json={
|
||||
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
|
||||
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
|
||||
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
|
||||
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
|
||||
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
|
||||
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
|
||||
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
|
||||
"mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
|
||||
"messages": ANY
|
||||
},
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations):
|
||||
with patch.object(GenericOpenAIAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
|
||||
patch.object(GenericOpenAIAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
|
||||
patch.object(GenericOpenAIAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
|
||||
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
|
||||
patch('homeassistant.helpers.template.Template') as template_mock, \
|
||||
patch('custom_components.llama_conversation.backends.generic_openai.async_get_clientsession') as get_clientsession_mock:
|
||||
|
||||
entry_mock.return_value = config_entry
|
||||
get_exposed_entities_mock.return_value = (
|
||||
{
|
||||
"light.kitchen_light": { "state": "on" },
|
||||
"light.office_lamp": { "state": "on" },
|
||||
"switch.downstairs_hallway": { "state": "off" },
|
||||
"fan.bedroom": { "state": "on" },
|
||||
},
|
||||
["light", "switch", "fan"]
|
||||
)
|
||||
|
||||
call_tool_mock.return_value = {"result": "success"}
|
||||
|
||||
agent_obj = GenericOpenAIAPIAgent(
|
||||
hass,
|
||||
config_entry
|
||||
)
|
||||
|
||||
all_mocks = {
|
||||
"requests_get": get_clientsession_mock.get,
|
||||
"requests_post": get_clientsession_mock.post
|
||||
}
|
||||
|
||||
yield agent_obj, all_mocks
|
||||
|
||||
async def test_generic_openai_agent(generic_openai_agent_fixture):
|
||||
|
||||
generic_openai_agent: TextGenerationWebuiAgent
|
||||
all_mocks: dict[str, MagicMock]
|
||||
generic_openai_agent, all_mocks = generic_openai_agent_fixture
|
||||
|
||||
response_mock = MagicMock()
|
||||
response_mock.json.return_value = {
|
||||
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
|
||||
"object": "text_completion",
|
||||
"created": 1589478378,
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [{
|
||||
"text": "I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
}),
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "length"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 12
|
||||
}
|
||||
}
|
||||
all_mocks["requests_post"].return_value = response_mock
|
||||
|
||||
# invoke the conversation agent
|
||||
conversation_id = "test-conversation"
|
||||
result = await generic_openai_agent.async_process(ConversationInput(
|
||||
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
# assert on results + check side effects
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/completions",
|
||||
json={
|
||||
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"top_p": generic_openai_agent.entry.options[CONF_TOP_P],
|
||||
"temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
|
||||
"max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
|
||||
"prompt": ANY
|
||||
},
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
|
||||
# reset mock stats
|
||||
for mock in all_mocks.values():
|
||||
mock.reset_mock()
|
||||
|
||||
# change options
|
||||
generic_openai_agent.entry.options[CONF_MAX_TOKENS] = 10
|
||||
generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
|
||||
generic_openai_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
|
||||
generic_openai_agent.entry.options[CONF_TEMPERATURE] = 2.0
|
||||
generic_openai_agent.entry.options[CONF_TOP_P] = 0.9
|
||||
|
||||
response_mock.json.return_value = {
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am saying something!\n" + json.dumps({
|
||||
"name": "HassTurnOn",
|
||||
"arguments": {
|
||||
"name": "light.kitchen_light"
|
||||
}
|
||||
}),
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 9,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 21
|
||||
}
|
||||
}
|
||||
|
||||
# do another turn of the same conversation, but using the chat endpoint
|
||||
result = await generic_openai_agent.async_process(ConversationInput(
|
||||
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
|
||||
))
|
||||
|
||||
assert result.response.speech['plain']['speech'] == "I am saying something!"
|
||||
|
||||
all_mocks["requests_post"].assert_called_once_with(
|
||||
"http://localhost:5000/v1/chat/completions",
|
||||
json={
|
||||
"model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
|
||||
"top_p": generic_openai_agent.entry.options[CONF_TOP_P],
|
||||
"temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
|
||||
"max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
|
||||
"messages": ANY
|
||||
},
|
||||
headers={ "Authorization": "Bearer OpenAI-API-Key" },
|
||||
timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
|
||||
)
|
||||
tests/llama_conversation/test_basic.py (new file, 105 lines)
@@ -0,0 +1,105 @@
|
||||
"""Lightweight smoke tests for backend helpers.
|
||||
|
||||
These avoid backend calls and only cover helper utilities to keep the suite green
|
||||
while the integration evolves. No integration code is modified.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
|
||||
|
||||
from custom_components.llama_conversation.backends.llamacpp import snapshot_settings
|
||||
from custom_components.llama_conversation.backends.ollama import OllamaAPIClient, _normalize_path
|
||||
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIClient
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_LLAMACPP_BATCH_SIZE,
|
||||
CONF_LLAMACPP_BATCH_THREAD_COUNT,
|
||||
CONF_LLAMACPP_THREAD_COUNT,
|
||||
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
||||
CONF_GBNF_GRAMMAR_FILE,
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
DEFAULT_CONTEXT_LENGTH,
|
||||
DEFAULT_LLAMACPP_BATCH_SIZE,
|
||||
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
|
||||
DEFAULT_LLAMACPP_THREAD_COUNT,
|
||||
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
||||
DEFAULT_GBNF_GRAMMAR_FILE,
|
||||
DEFAULT_PROMPT_CACHING_ENABLED,
|
||||
CONF_GENERIC_OPENAI_PATH,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def hass_defaults(hass):
|
||||
return hass
|
||||
|
||||
|
||||
def test_snapshot_settings_defaults():
|
||||
options = {CONF_CHAT_MODEL: "test-model"}
|
||||
snap = snapshot_settings(options)
|
||||
assert snap[CONF_CONTEXT_LENGTH] == DEFAULT_CONTEXT_LENGTH
|
||||
assert snap[CONF_LLAMACPP_BATCH_SIZE] == DEFAULT_LLAMACPP_BATCH_SIZE
|
||||
assert snap[CONF_LLAMACPP_THREAD_COUNT] == DEFAULT_LLAMACPP_THREAD_COUNT
|
||||
assert snap[CONF_LLAMACPP_BATCH_THREAD_COUNT] == DEFAULT_LLAMACPP_BATCH_THREAD_COUNT
|
||||
assert snap[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] == DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION
|
||||
assert snap[CONF_GBNF_GRAMMAR_FILE] == DEFAULT_GBNF_GRAMMAR_FILE
|
||||
assert snap[CONF_PROMPT_CACHING_ENABLED] == DEFAULT_PROMPT_CACHING_ENABLED
|
||||
|
||||
|
||||
def test_snapshot_settings_overrides():
|
||||
options = {
|
||||
CONF_CONTEXT_LENGTH: 4096,
|
||||
CONF_LLAMACPP_BATCH_SIZE: 64,
|
||||
CONF_LLAMACPP_THREAD_COUNT: 6,
|
||||
CONF_LLAMACPP_BATCH_THREAD_COUNT: 3,
|
||||
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION: True,
|
||||
CONF_GBNF_GRAMMAR_FILE: "custom.gbnf",
|
||||
CONF_PROMPT_CACHING_ENABLED: True,
|
||||
}
|
||||
snap = snapshot_settings(options)
|
||||
assert snap[CONF_CONTEXT_LENGTH] == 4096
|
||||
assert snap[CONF_LLAMACPP_BATCH_SIZE] == 64
|
||||
assert snap[CONF_LLAMACPP_THREAD_COUNT] == 6
|
||||
assert snap[CONF_LLAMACPP_BATCH_THREAD_COUNT] == 3
|
||||
assert snap[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] is True
|
||||
assert snap[CONF_GBNF_GRAMMAR_FILE] == "custom.gbnf"
|
||||
assert snap[CONF_PROMPT_CACHING_ENABLED] is True
|
||||
|
||||
|
||||
def test_ollama_keep_alive_formatting():
|
||||
assert OllamaAPIClient._format_keep_alive("0") == 0
|
||||
assert OllamaAPIClient._format_keep_alive("0.0") == 0
|
||||
assert OllamaAPIClient._format_keep_alive(5) == "5m"
|
||||
assert OllamaAPIClient._format_keep_alive("15") == "15m"
|
||||
|
||||
|
||||
def test_generic_openai_name_and_path(hass_defaults):
|
||||
client = GenericOpenAIAPIClient(
|
||||
hass_defaults,
|
||||
{
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "8080",
|
||||
CONF_SSL: False,
|
||||
CONF_GENERIC_OPENAI_PATH: "v1",
|
||||
CONF_CHAT_MODEL: "demo",
|
||||
},
|
||||
)
|
||||
name = client.get_name(
|
||||
{
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "8080",
|
||||
CONF_SSL: False,
|
||||
CONF_GENERIC_OPENAI_PATH: "v1",
|
||||
}
|
||||
)
|
||||
assert "Generic OpenAI" in name
|
||||
assert "localhost" in name
|
||||
|
||||
|
||||
def test_normalize_path_helper():
|
||||
assert _normalize_path(None) == ""
|
||||
assert _normalize_path("") == ""
|
||||
assert _normalize_path("/v1/") == "/v1"
|
||||
assert _normalize_path("v2") == "/v2"
|
||||
@@ -1,350 +1,204 @@
|
||||
"""Config flow option schema tests to ensure options are wired per-backend."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from homeassistant import config_entries, setup
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.const import (
|
||||
CONF_HOST,
|
||||
CONF_PORT,
|
||||
CONF_SSL,
|
||||
CONF_LLM_HASS_API,
|
||||
)
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema, ConfigFlow
|
||||
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_PROMPT,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_MIN_P,
|
||||
CONF_TYPICAL_P,
|
||||
CONF_REQUEST_TIMEOUT,
|
||||
CONF_BACKEND_TYPE,
|
||||
CONF_DOWNLOADED_MODEL_FILE,
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_PROMPT_TEMPLATE,
|
||||
CONF_TOOL_FORMAT,
|
||||
CONF_TOOL_MULTI_TURN_CHAT,
|
||||
CONF_ENABLE_FLASH_ATTENTION,
|
||||
CONF_USE_GBNF_GRAMMAR,
|
||||
CONF_GBNF_GRAMMAR_FILE,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
||||
CONF_TEXT_GEN_WEBUI_PRESET,
|
||||
CONF_OPENAI_API_KEY,
|
||||
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
|
||||
CONF_REFRESH_SYSTEM_PROMPT,
|
||||
CONF_REMEMBER_CONVERSATION,
|
||||
CONF_REMEMBER_NUM_INTERACTIONS,
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
CONF_PROMPT_CACHING_INTERVAL,
|
||||
CONF_SERVICE_CALL_REGEX,
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT,
|
||||
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
CONF_OLLAMA_KEEP_ALIVE_MIN,
|
||||
CONF_OLLAMA_JSON_MODE,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_BATCH_SIZE,
|
||||
CONF_THREAD_COUNT,
|
||||
CONF_BATCH_THREAD_COUNT,
|
||||
BACKEND_TYPE_LLAMA_HF,
|
||||
BACKEND_TYPE_LLAMA_EXISTING,
|
||||
BACKEND_TYPE_LLAMA_CPP,
|
||||
BACKEND_TYPE_TEXT_GEN_WEBUI,
|
||||
BACKEND_TYPE_GENERIC_OPENAI,
|
||||
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
|
||||
BACKEND_TYPE_LLAMA_CPP_SERVER,
|
||||
BACKEND_TYPE_OLLAMA,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_K,
|
||||
DEFAULT_TOP_P,
|
||||
DEFAULT_MIN_P,
|
||||
DEFAULT_TYPICAL_P,
|
||||
DEFAULT_BACKEND_TYPE,
|
||||
DEFAULT_REQUEST_TIMEOUT,
|
||||
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
DEFAULT_PROMPT_TEMPLATE,
|
||||
DEFAULT_ENABLE_FLASH_ATTENTION,
|
||||
DEFAULT_USE_GBNF_GRAMMAR,
|
||||
DEFAULT_GBNF_GRAMMAR_FILE,
|
||||
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_GBNF_GRAMMAR_FILE,
|
||||
CONF_LLAMACPP_BATCH_SIZE,
|
||||
CONF_LLAMACPP_BATCH_THREAD_COUNT,
|
||||
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
||||
CONF_LLAMACPP_THREAD_COUNT,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_MIN_P,
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
||||
CONF_OLLAMA_JSON_MODE,
|
||||
CONF_OLLAMA_KEEP_ALIVE_MIN,
|
||||
CONF_PROMPT,
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
CONF_PROMPT_CACHING_INTERVAL,
|
||||
CONF_REQUEST_TIMEOUT,
|
||||
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
CONF_TEXT_GEN_WEBUI_PRESET,
|
||||
CONF_THINKING_PREFIX,
|
||||
CONF_TOOL_CALL_PREFIX,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_TYPICAL_P,
|
||||
CONF_TEMPERATURE,
|
||||
DEFAULT_CONTEXT_LENGTH,
|
||||
DEFAULT_LLAMACPP_BATCH_SIZE,
|
||||
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
|
||||
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
||||
DEFAULT_LLAMACPP_THREAD_COUNT,
|
||||
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
||||
DEFAULT_REFRESH_SYSTEM_PROMPT,
|
||||
DEFAULT_REMEMBER_CONVERSATION,
|
||||
DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
||||
DEFAULT_PROMPT_CACHING_ENABLED,
|
||||
DEFAULT_PROMPT_CACHING_INTERVAL,
|
||||
DEFAULT_SERVICE_CALL_REGEX,
|
||||
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
|
||||
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
|
||||
DEFAULT_OLLAMA_JSON_MODE,
|
||||
DEFAULT_CONTEXT_LENGTH,
|
||||
DEFAULT_BATCH_SIZE,
|
||||
DEFAULT_THREAD_COUNT,
|
||||
DEFAULT_BATCH_THREAD_COUNT,
|
||||
DOMAIN,
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_CACHING_INTERVAL,
|
||||
DEFAULT_REQUEST_TIMEOUT,
|
||||
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
DEFAULT_THINKING_PREFIX,
|
||||
DEFAULT_TOOL_CALL_PREFIX,
|
||||
DEFAULT_TOP_K,
|
||||
DEFAULT_TOP_P,
|
||||
DEFAULT_TYPICAL_P,
|
||||
)
|
||||
|
||||
# async def test_validate_config_flow_llama_hf(hass: HomeAssistant):
|
||||
# result = await hass.config_entries.flow.async_init(
|
||||
# DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
# )
|
||||
# assert result["type"] == FlowResultType.FORM
|
||||
# assert result["errors"] is None
|
||||
|
||||
# result2 = await hass.config_entries.flow.async_configure(
|
||||
# result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_HF },
|
||||
# )
|
||||
# assert result2["type"] == FlowResultType.FORM
|
||||
|
||||
# with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry:
|
||||
# result3 = await hass.config_entries.flow.async_configure(
|
||||
# result2["flow_id"],
|
||||
# TEST_DATA,
|
||||
# )
|
||||
# await hass.async_block_till_done()
|
||||
|
||||
# assert result3["type"] == "create_entry"
|
||||
# assert result3["title"] == ""
|
||||
# assert result3["data"] == {
|
||||
# # ACCOUNT_ID: TEST_DATA["account_id"],
|
||||
# # CONF_PASSWORD: TEST_DATA["password"],
|
||||
# # CONNECTION_TYPE: CLOUD,
|
||||
# }
|
||||
# assert result3["options"] == {}
|
||||
# assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
@pytest.fixture
|
||||
def validate_connections_mock():
|
||||
validate_mock = MagicMock()
|
||||
with patch.object(ConfigFlow, '_validate_text_generation_webui', new=validate_mock), \
|
||||
patch.object(ConfigFlow, '_validate_ollama', new=validate_mock):
|
||||
yield validate_mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_setup_entry():
|
||||
with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry, \
|
||||
patch("custom_components.llama_conversation.async_unload_entry", return_value=True):
|
||||
yield mock_setup_entry
|
||||
|
||||
async def test_validate_config_flow_generic_openai(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["errors"] == {}
|
||||
assert result["step_id"] == "pick_backend"
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI },
|
||||
def _schema(hass: HomeAssistant, backend: str, options: dict | None = None):
|
||||
return local_llama_config_option_schema(
|
||||
hass=hass,
|
||||
language="en",
|
||||
options=options or {},
|
||||
backend_type=backend,
|
||||
subentry_type="conversation",
|
||||
)
|
||||
|
||||
assert result2["type"] == FlowResultType.FORM
|
||||
assert result2["errors"] == {}
|
||||
assert result2["step_id"] == "remote_model"
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"],
|
||||
{
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5000",
|
||||
CONF_SSL: False,
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
},
|
||||
)
|
||||
def _get_default(schema: dict, key_name: str):
|
||||
for key in schema:
|
||||
if getattr(key, "schema", None) == key_name:
|
||||
default = getattr(key, "default", None)
|
||||
return default() if callable(default) else default
|
||||
raise AssertionError(f"Key {key_name} not found in schema")
|
||||
|
||||
assert result3["type"] == FlowResultType.FORM
|
||||
assert result3["errors"] == {}
|
||||
assert result3["step_id"] == "model_parameters"
|
||||
|
||||
options_dict = {
|
||||
CONF_PROMPT: DEFAULT_PROMPT,
|
||||
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
|
||||
CONF_TOP_P: DEFAULT_TOP_P,
|
||||
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
|
||||
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
|
||||
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
|
||||
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
|
||||
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
||||
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
||||
def _get_suggested(schema: dict, key_name: str):
|
||||
for key in schema:
|
||||
if getattr(key, "schema", None) == key_name:
|
||||
return (getattr(key, "description", {}) or {}).get("suggested_value")
|
||||
raise AssertionError(f"Key {key_name} not found in schema")
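# _get_default and _get_suggested unwrap voluptuous schema keys: each key exposes the
# option name via .schema, its default via .default (a callable for voluptuous markers),
# and any pre-filled form value via description={"suggested_value": ...}. Illustrative
# shape only, assuming the schema is built with vol.Optional markers:
#   vol.Optional(CONF_TOP_K, default=DEFAULT_TOP_K, description={"suggested_value": 12})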
|
||||
|
||||
|
||||
def test_schema_llama_cpp_defaults_and_overrides(hass: HomeAssistant):
|
||||
overrides = {
|
||||
CONF_CONTEXT_LENGTH: 4096,
|
||||
CONF_LLAMACPP_BATCH_SIZE: 8,
|
||||
CONF_LLAMACPP_THREAD_COUNT: 6,
|
||||
CONF_LLAMACPP_BATCH_THREAD_COUNT: 3,
|
||||
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION: True,
|
||||
CONF_PROMPT_CACHING_INTERVAL: 15,
|
||||
CONF_TOP_K: 12,
|
||||
CONF_TOOL_CALL_PREFIX: "<tc>",
|
||||
}
|
||||
|
||||
result4 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"], options_dict
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP, overrides)
|
||||
|
||||
assert result4["type"] == "create_entry"
|
||||
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
|
||||
assert result4["data"] == {
|
||||
CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI,
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5000",
|
||||
CONF_SSL: False,
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
expected_keys = {
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_MIN_P,
|
||||
CONF_TYPICAL_P,
|
||||
CONF_PROMPT_CACHING_ENABLED,
|
||||
CONF_PROMPT_CACHING_INTERVAL,
|
||||
CONF_GBNF_GRAMMAR_FILE,
|
||||
CONF_LLAMACPP_BATCH_SIZE,
|
||||
CONF_LLAMACPP_THREAD_COUNT,
|
||||
CONF_LLAMACPP_BATCH_THREAD_COUNT,
|
||||
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
||||
}
|
||||
assert result4["options"] == options_dict
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
assert expected_keys.issubset({getattr(k, "schema", None) for k in schema})
|
||||
|
||||
async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations, validate_connections_mock):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["errors"] == {}
|
||||
assert result["step_id"] == "pick_backend"
|
||||
assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
|
||||
assert _get_default(schema, CONF_LLAMACPP_BATCH_SIZE) == DEFAULT_LLAMACPP_BATCH_SIZE
|
||||
assert _get_default(schema, CONF_LLAMACPP_THREAD_COUNT) == DEFAULT_LLAMACPP_THREAD_COUNT
|
||||
assert _get_default(schema, CONF_LLAMACPP_BATCH_THREAD_COUNT) == DEFAULT_LLAMACPP_BATCH_THREAD_COUNT
|
||||
assert _get_default(schema, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION) is DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION
|
||||
assert _get_default(schema, CONF_PROMPT_CACHING_INTERVAL) == DEFAULT_PROMPT_CACHING_INTERVAL
|
||||
# suggested values should reflect overrides
|
||||
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 4096
|
||||
assert _get_suggested(schema, CONF_LLAMACPP_BATCH_SIZE) == 8
|
||||
assert _get_suggested(schema, CONF_LLAMACPP_THREAD_COUNT) == 6
|
||||
assert _get_suggested(schema, CONF_LLAMACPP_BATCH_THREAD_COUNT) == 3
|
||||
assert _get_suggested(schema, CONF_LLAMACPP_ENABLE_FLASH_ATTENTION) is True
|
||||
assert _get_suggested(schema, CONF_PROMPT_CACHING_INTERVAL) == 15
|
||||
assert _get_suggested(schema, CONF_TOP_K) == 12
|
||||
assert _get_suggested(schema, CONF_TOOL_CALL_PREFIX) == "<tc>"
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA },
|
||||
)
|
||||
|
||||
assert result2["type"] == FlowResultType.FORM
|
||||
assert result2["errors"] == {}
|
||||
assert result2["step_id"] == "remote_model"
|
||||
|
||||
# simulate incorrect settings on first try
|
||||
validate_connections_mock.side_effect = [
|
||||
("failed_to_connect", Exception("ConnectionError"), []),
|
||||
(None, None, [])
|
||||
]
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"],
|
||||
{
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5000",
|
||||
CONF_SSL: False,
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
},
|
||||
)
|
||||
|
||||
assert result3["type"] == FlowResultType.FORM
|
||||
assert len(result3["errors"]) == 1
|
||||
assert "base" in result3["errors"]
|
||||
assert result3["step_id"] == "remote_model"
|
||||
|
||||
# retry
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"],
|
||||
{
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5001",
|
||||
CONF_SSL: False,
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
},
|
||||
)
|
||||
|
||||
assert result3["type"] == FlowResultType.FORM
|
||||
assert result3["errors"] == {}
|
||||
assert result3["step_id"] == "model_parameters"
|
||||
|
||||
options_dict = {
|
||||
CONF_PROMPT: DEFAULT_PROMPT,
|
||||
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
|
||||
CONF_TOP_P: DEFAULT_TOP_P,
|
||||
CONF_TOP_K: DEFAULT_TOP_K,
|
||||
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
|
||||
CONF_TYPICAL_P: DEFAULT_MIN_P,
|
||||
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
|
||||
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
|
||||
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
|
||||
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
||||
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
||||
CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
|
||||
CONF_OLLAMA_KEEP_ALIVE_MIN: DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
|
||||
CONF_OLLAMA_JSON_MODE: DEFAULT_OLLAMA_JSON_MODE,
|
||||
def test_schema_text_gen_webui_options_preserved(hass: HomeAssistant):
|
||||
overrides = {
|
||||
CONF_REQUEST_TIMEOUT: 123,
|
||||
CONF_TEXT_GEN_WEBUI_PRESET: "custom-preset",
|
||||
CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
CONF_CONTEXT_LENGTH: 2048,
|
||||
}
|
||||
|
||||
result4 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"], options_dict
|
||||
schema = _schema(hass, BACKEND_TYPE_TEXT_GEN_WEBUI, overrides)
|
||||
|
||||
expected = {CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, CONF_REQUEST_TIMEOUT, CONF_CONTEXT_LENGTH}
|
||||
assert expected.issubset({getattr(k, "schema", None) for k in schema})
|
||||
assert _get_default(schema, CONF_REQUEST_TIMEOUT) == DEFAULT_REQUEST_TIMEOUT
|
||||
assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
|
||||
assert _get_suggested(schema, CONF_REQUEST_TIMEOUT) == 123
|
||||
assert _get_suggested(schema, CONF_TEXT_GEN_WEBUI_PRESET) == "custom-preset"
|
||||
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 2048
|
||||
|
||||
|
||||
def test_schema_generic_openai_options_preserved(hass: HomeAssistant):
|
||||
overrides = {CONF_TOP_P: 0.25, CONF_REQUEST_TIMEOUT: 321}
|
||||
|
||||
schema = _schema(hass, BACKEND_TYPE_GENERIC_OPENAI, overrides)
|
||||
|
||||
assert {CONF_TOP_P, CONF_REQUEST_TIMEOUT}.issubset({getattr(k, "schema", None) for k in schema})
|
||||
assert _get_default(schema, CONF_TOP_P) == DEFAULT_TOP_P
|
||||
assert _get_default(schema, CONF_REQUEST_TIMEOUT) == DEFAULT_REQUEST_TIMEOUT
|
||||
assert _get_suggested(schema, CONF_TOP_P) == 0.25
|
||||
assert _get_suggested(schema, CONF_REQUEST_TIMEOUT) == 321
|
||||
# Base prompt options still present
|
||||
prompt_default = _get_default(schema, CONF_PROMPT)
|
||||
assert prompt_default is not None and "You are 'Al'" in prompt_default
|
||||
assert _get_default(schema, CONF_NUM_IN_CONTEXT_EXAMPLES) == DEFAULT_NUM_IN_CONTEXT_EXAMPLES
|
||||
|
||||
|
||||
def test_schema_llama_cpp_server_includes_gbnf(hass: HomeAssistant):
|
||||
schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP_SERVER)
|
||||
keys = {getattr(k, "schema", None) for k in schema}
|
||||
|
||||
assert {CONF_MAX_TOKENS, CONF_TOP_K, CONF_GBNF_GRAMMAR_FILE}.issubset(keys)
|
||||
assert _get_default(schema, CONF_GBNF_GRAMMAR_FILE) == "output.gbnf"
|
||||
|
||||
|
||||
def test_schema_ollama_defaults_and_overrides(hass: HomeAssistant):
|
||||
overrides = {CONF_OLLAMA_KEEP_ALIVE_MIN: 5, CONF_CONTEXT_LENGTH: 1024, CONF_TOP_K: 7}
|
||||
schema = _schema(hass, BACKEND_TYPE_OLLAMA, overrides)
|
||||
|
||||
assert {CONF_MAX_TOKENS, CONF_CONTEXT_LENGTH, CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE}.issubset(
|
||||
{getattr(k, "schema", None) for k in schema}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert _get_default(schema, CONF_OLLAMA_KEEP_ALIVE_MIN) == DEFAULT_OLLAMA_KEEP_ALIVE_MIN
|
||||
assert _get_default(schema, CONF_OLLAMA_JSON_MODE) is DEFAULT_OLLAMA_JSON_MODE
|
||||
assert _get_default(schema, CONF_CONTEXT_LENGTH) == DEFAULT_CONTEXT_LENGTH
|
||||
assert _get_default(schema, CONF_TOP_K) == DEFAULT_TOP_K
|
||||
assert _get_suggested(schema, CONF_OLLAMA_KEEP_ALIVE_MIN) == 5
|
||||
assert _get_suggested(schema, CONF_CONTEXT_LENGTH) == 1024
|
||||
assert _get_suggested(schema, CONF_TOP_K) == 7
|
||||
|
||||
assert result4["type"] == "create_entry"
|
||||
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
|
||||
assert result4["data"] == {
|
||||
CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA,
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "5001",
|
||||
CONF_SSL: False,
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
}
|
||||
assert result4["options"] == options_dict
|
||||
mock_setup_entry.assert_called_once()
|
||||
|
||||
# TODO: write tests for config flow setup for llama.cpp (both versions) + text-generation-webui
|
||||
def test_schema_includes_llm_api_selector(monkeypatch, hass: HomeAssistant):
|
||||
monkeypatch.setattr(
|
||||
"custom_components.llama_conversation.config_flow.llm.async_get_apis",
|
||||
lambda _hass: [type("API", (), {"id": "dummy", "name": "Dummy API", "tools": []})()],
|
||||
)
|
||||
schema = _schema(hass, BACKEND_TYPE_LLAMA_CPP)
|
||||
|
||||
def test_validate_options_schema(hass: HomeAssistant):
|
||||
|
||||
universal_options = [
|
||||
CONF_LLM_HASS_API, CONF_PROMPT, CONF_PROMPT_TEMPLATE, CONF_TOOL_FORMAT, CONF_TOOL_MULTI_TURN_CHAT,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES,
|
||||
CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_SERVICE_CALL_REGEX, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS,
|
||||
]
|
||||
|
||||
options_llama_hf = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_HF)
|
||||
assert set(options_llama_hf.keys()) == set(universal_options + [
|
||||
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
|
||||
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
|
||||
CONF_CONTEXT_LENGTH, # supports context length
|
||||
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
|
||||
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
|
||||
])
|
||||
|
||||
options_llama_existing = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_EXISTING)
|
||||
assert set(options_llama_existing.keys()) == set(universal_options + [
|
||||
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
|
||||
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
|
||||
CONF_CONTEXT_LENGTH, # supports context length
|
||||
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
|
||||
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
|
||||
])
|
||||
|
||||
options_ollama = local_llama_config_option_schema(hass, None, BACKEND_TYPE_OLLAMA)
|
||||
assert set(options_ollama.keys()) == set(universal_options + [
|
||||
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_TYPICAL_P, # supports top_k temperature, top_p and typical_p samplers
|
||||
CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE, # ollama specific
|
||||
CONF_CONTEXT_LENGTH, # supports context length
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
||||
])
|
||||
|
||||
options_text_gen_webui = local_llama_config_option_schema(hass, None, BACKEND_TYPE_TEXT_GEN_WEBUI)
|
||||
assert set(options_text_gen_webui.keys()) == set(universal_options + [
|
||||
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
|
||||
CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, # text-gen-webui specific
|
||||
CONF_CONTEXT_LENGTH, # supports context length
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
||||
])
|
||||
|
||||
options_generic_openai = local_llama_config_option_schema(hass, None, BACKEND_TYPE_GENERIC_OPENAI)
|
||||
assert set(options_generic_openai.keys()) == set(universal_options + [
|
||||
CONF_TEMPERATURE, CONF_TOP_P, # only supports top_p and temperature sampling
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
||||
])
|
||||
|
||||
options_llama_cpp_python_server = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER)
|
||||
assert set(options_llama_cpp_python_server.keys()) == set(universal_options + [
|
||||
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports top_k, temperature, and top p sampling
|
||||
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
|
||||
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
||||
])
|
||||
assert _get_default(schema, CONF_LLM_HASS_API) is None
|
||||
# Base prompt and thinking prefixes use defaults when not overridden
|
||||
prompt_default = _get_default(schema, CONF_PROMPT)
|
||||
assert prompt_default is not None and "You are 'Al'" in prompt_default
|
||||
assert _get_default(schema, CONF_THINKING_PREFIX) == DEFAULT_THINKING_PREFIX
|
||||
assert _get_default(schema, CONF_TOOL_CALL_PREFIX) == DEFAULT_TOOL_CALL_PREFIX
|
||||
|
||||

tests/llama_conversation/test_conversation_agent.py (new file, 114 lines)
@@ -0,0 +1,114 @@
"""Tests for LocalLLMAgent async_process."""
|
||||
|
||||
import pytest
|
||||
from contextlib import contextmanager
|
||||
|
||||
from homeassistant.components.conversation import ConversationInput, SystemContent, AssistantContent
|
||||
from homeassistant.const import MATCH_ALL
|
||||
|
||||
from custom_components.llama_conversation.conversation import LocalLLMAgent
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_PROMPT,
|
||||
DEFAULT_PROMPT,
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
|
||||
class DummyClient:
|
||||
def __init__(self, hass):
|
||||
self.hass = hass
|
||||
self.generated_prompts = []
|
||||
|
||||
def _generate_system_prompt(self, prompt_template, llm_api, entity_options):
|
||||
self.generated_prompts.append(prompt_template)
|
||||
return "rendered-system-prompt"
|
||||
|
||||
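    # NOTE (assumption based on this test's flow): LocalLLMAgent presumably awaits this call
    # and then iterates the returned async generator, so yielding a single AssistantContent
    # chunk is enough for async_process to produce a spoken response.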
    async def _async_generate(self, conv, agent_id, chat_log, entity_options):
        async def gen():
            yield AssistantContent(agent_id=agent_id, content="hello from llm")
        return gen()


class DummySubentry:
    def __init__(self, subentry_id="sub1", title="Test Agent", chat_model="model"):
        self.subentry_id = subentry_id
        self.title = title
        self.subentry_type = DOMAIN
        self.data = {CONF_CHAT_MODEL: chat_model}


class DummyEntry:
    def __init__(self, entry_id="entry1", options=None, subentry=None, runtime_data=None):
        self.entry_id = entry_id
        self.options = options or {}
        self.subentries = {subentry.subentry_id: subentry}
        self.runtime_data = runtime_data

    def add_update_listener(self, _cb):
        return lambda: None


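# Minimal stand-ins for the chat session / chat log context managers that async_process opens;
# only the attributes the agent is assumed to read (content, llm_api) are provided.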
class FakeChatLog:
    def __init__(self):
        self.content = []
        self.llm_api = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False


class FakeChatSession:
    def __enter__(self):
        return {}

    def __exit__(self, exc_type, exc, tb):
        return False


@pytest.mark.asyncio
async def test_async_process_generates_response(monkeypatch, hass):
    client = DummyClient(hass)
    subentry = DummySubentry()
    entry = DummyEntry(subentry=subentry, runtime_data=client)

    # Make entry discoverable through hass data as LocalLLMEntity expects.
    hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry

    @contextmanager
    def fake_chat_session(_hass, _conversation_id):
        yield FakeChatSession()

    @contextmanager
    def fake_chat_log(_hass, _session, _user_input):
        yield FakeChatLog()

    monkeypatch.setattr(
        "custom_components.llama_conversation.conversation.chat_session.async_get_chat_session",
        fake_chat_session,
    )
    monkeypatch.setattr(
        "custom_components.llama_conversation.conversation.conversation.async_get_chat_log",
        fake_chat_log,
    )

    agent = LocalLLMAgent(hass, entry, subentry, client)

    result = await agent.async_process(
        ConversationInput(
            text="turn on the lights",
            context=None,
            conversation_id="conv-id",
            device_id=None,
            language="en",
            agent_id="agent-1",
        )
    )

    assert result.response.speech["plain"]["speech"] == "hello from llm"
    # System prompt should be rendered once when message history is empty.
    assert client.generated_prompts == [DEFAULT_PROMPT]
    assert agent.supported_languages == MATCH_ALL

tests/llama_conversation/test_entity.py (new file, 162 lines)
@@ -0,0 +1,162 @@
"""Tests for LocalLLMClient helpers in entity.py."""
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import pytest
|
||||
from json import JSONDecodeError
|
||||
|
||||
from custom_components.llama_conversation.entity import LocalLLMClient
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
DEFAULT_TOOL_CALL_PREFIX,
|
||||
DEFAULT_TOOL_CALL_SUFFIX,
|
||||
DEFAULT_THINKING_PREFIX,
|
||||
DEFAULT_THINKING_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
class DummyLocalClient(LocalLLMClient):
|
||||
@staticmethod
|
||||
def get_name(_client_options):
|
||||
return "dummy"
|
||||
|
||||
|
||||
class DummyLLMApi:
|
||||
def __init__(self):
|
||||
self.tools = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(hass):
|
||||
# Disable ICL loading during tests to avoid filesystem access.
|
||||
return DummyLocalClient(hass, {CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_parse_completion_parses_tool_call(client):
|
||||
raw_tool = '{"name":"light.turn_on","arguments":{"brightness":0.5,"to_say":" acknowledged"}}'
|
||||
completion = (
|
||||
f"{DEFAULT_THINKING_PREFIX}internal{DEFAULT_THINKING_SUFFIX}"
|
||||
f"hello {DEFAULT_TOOL_CALL_PREFIX}{raw_tool}{DEFAULT_TOOL_CALL_SUFFIX}"
|
||||
)
|
||||
|
||||
result = await client._async_parse_completion(DummyLLMApi(), "agent-id", {}, completion)
|
||||
|
||||
assert result.response.strip().startswith("hello")
|
||||
assert "acknowledged" in result.response
|
||||
assert result.tool_calls
|
||||
tool_call = result.tool_calls[0]
|
||||
assert tool_call.tool_name == "light.turn_on"
|
||||
assert tool_call.tool_args["brightness"] == 127
|
||||
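    # NOTE (inferred from the asserted values, not from the parser's documentation): the parser
    # appears to scale fractional brightness into the 0-255 range (0.5 -> 127) and to fold the
    # tool call's "to_say" text back into the spoken response.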


@pytest.mark.asyncio
async def test_async_parse_completion_ignores_tools_without_llm_api(client):
    raw_tool = '{"name":"light.turn_on","arguments":{"brightness":1}}'
    completion = f"hello {DEFAULT_TOOL_CALL_PREFIX}{raw_tool}{DEFAULT_TOOL_CALL_SUFFIX}"

    result = await client._async_parse_completion(None, "agent-id", {}, completion)

    assert result.tool_calls == []
    assert result.response.strip() == "hello"


@pytest.mark.asyncio
async def test_async_parse_completion_malformed_tool_raises(client):
    bad_tool = f"{DEFAULT_TOOL_CALL_PREFIX}{{not-json{DEFAULT_TOOL_CALL_SUFFIX}"

    with pytest.raises(JSONDecodeError):
        await client._async_parse_completion(DummyLLMApi(), "agent-id", {}, bad_tool)


@pytest.mark.asyncio
async def test_async_stream_parse_completion_handles_streamed_tool_call(client):
    async def token_generator():
        yield ("Hi", None)
        yield (
            None,
            [
                {
                    "function": {
                        "name": "light.turn_on",
                        "arguments": {"brightness": 0.25, "to_say": " ok"},
                    }
                }
            ],
        )

    stream = client._async_stream_parse_completion(
        DummyLLMApi(), "agent-id", {}, anext_token=token_generator()
    )

    results = [chunk async for chunk in stream]

    assert results[0].response == "Hi"
    assert results[1].response.strip() == "ok"
    assert results[1].tool_calls[0].tool_args["brightness"] == 63


@pytest.mark.asyncio
async def test_async_stream_parse_completion_malformed_tool_raises(client):
    async def token_generator():
        yield ("Hi", None)
        yield (None, ["{not-json"])

    with pytest.raises(JSONDecodeError):
        async for _chunk in client._async_stream_parse_completion(
            DummyLLMApi(), "agent-id", {}, anext_token=token_generator()
        ):
            pass


@pytest.mark.asyncio
async def test_async_stream_parse_completion_ignores_tools_without_llm_api(client):
    async def token_generator():
        yield ("Hi", None)
        yield (None, ["{}"])

    results = [chunk async for chunk in client._async_stream_parse_completion(
        None, "agent-id", {}, anext_token=token_generator()
    )]

    assert results[0].response == "Hi"
    assert results[1].tool_calls is None


@pytest.mark.asyncio
async def test_async_get_exposed_entities_respects_exposure(monkeypatch, client, hass):
    hass.states.async_set("light.exposed", "on", {"friendly_name": "Lamp"})
    hass.states.async_set("switch.hidden", "off", {"friendly_name": "Hidden"})

    monkeypatch.setattr(
        "custom_components.llama_conversation.entity.async_should_expose",
        lambda _hass, _domain, entity_id: not entity_id.endswith("hidden"),
    )

    exposed = client._async_get_exposed_entities()

    assert "light.exposed" in exposed
    assert "switch.hidden" not in exposed
    assert exposed["light.exposed"]["friendly_name"] == "Lamp"
    assert exposed["light.exposed"]["state"] == "on"


@pytest.mark.asyncio
async def test_generate_system_prompt_renders(monkeypatch, client, hass):
    hass.states.async_set("light.kitchen", "on", {"friendly_name": "Kitchen"})
    monkeypatch.setattr(
        "custom_components.llama_conversation.entity.async_should_expose",
        lambda _hass, _domain, _entity_id: True,
    )

    rendered = client._generate_system_prompt(
        "Devices:\n{{ formatted_devices }}",
        llm_api=None,
        entity_options={CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: []},
    )
    if inspect.iscoroutine(rendered):
        rendered = await rendered

    assert isinstance(rendered, str)
    assert "light.kitchen" in rendered

tests/llama_conversation/test_migrations.py (new file, 159 lines)
@@ -0,0 +1,159 @@
"""Regression tests for config entry migration in __init__.py."""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_HOST, CONF_PORT, CONF_SSL
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from pytest_homeassistant_custom_component.common import MockConfigEntry
|
||||
|
||||
from custom_components.llama_conversation import async_migrate_entry
|
||||
from custom_components.llama_conversation.const import (
|
||||
BACKEND_TYPE_LLAMA_CPP,
|
||||
BACKEND_TYPE_GENERIC_OPENAI,
|
||||
BACKEND_TYPE_LLAMA_CPP_SERVER,
|
||||
CONF_BACKEND_TYPE,
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_CONTEXT_LENGTH,
|
||||
CONF_DOWNLOADED_MODEL_FILE,
|
||||
CONF_DOWNLOADED_MODEL_QUANTIZATION,
|
||||
CONF_GENERIC_OPENAI_PATH,
|
||||
CONF_PROMPT,
|
||||
CONF_REQUEST_TIMEOUT,
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_v1_is_rejected(hass):
|
||||
entry = MockConfigEntry(domain=DOMAIN, data={CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_CPP}, version=1)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
result = await async_migrate_entry(hass, entry)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_v2_creates_subentry_and_updates_entry(monkeypatch, hass):
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
title="llama 'Test Agent' entry",
|
||||
data={CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI},
|
||||
options={
|
||||
CONF_HOST: "localhost",
|
||||
CONF_PORT: "8080",
|
||||
CONF_SSL: False,
|
||||
CONF_GENERIC_OPENAI_PATH: "v1",
|
||||
CONF_PROMPT: "hello",
|
||||
CONF_REQUEST_TIMEOUT: 90,
|
||||
CONF_CHAT_MODEL: "model-x",
|
||||
CONF_CONTEXT_LENGTH: 1024,
|
||||
},
|
||||
version=2,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
added_subentries = []
|
||||
update_calls = []
|
||||
|
||||
def fake_add_subentry(cfg_entry, subentry):
|
||||
added_subentries.append((cfg_entry, subentry))
|
||||
|
||||
def fake_update_entry(cfg_entry, **kwargs):
|
||||
update_calls.append(kwargs)
|
||||
|
||||
monkeypatch.setattr(hass.config_entries, "async_add_subentry", fake_add_subentry)
|
||||
monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)
|
||||
|
||||
result = await async_migrate_entry(hass, entry)
|
||||
|
||||
assert result is True
|
||||
assert added_subentries, "Subentry should be added"
|
||||
subentry = added_subentries[0][1]
|
||||
assert isinstance(subentry, ConfigSubentry)
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data[CONF_CHAT_MODEL] == "model-x"
|
||||
# Entry should be updated to version 3 with data/options separated
|
||||
assert any(call.get("version") == 3 for call in update_calls)
|
||||
last_options = [c["options"] for c in update_calls if "options" in c][-1]
|
||||
assert last_options[CONF_HOST] == "localhost"
|
||||
assert CONF_PROMPT not in last_options # moved to subentry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_v3_minor0_downloads_model(monkeypatch, hass):
|
||||
sub_data = {
|
||||
CONF_CHAT_MODEL: "model-a",
|
||||
CONF_DOWNLOADED_MODEL_QUANTIZATION: "Q4_K_M",
|
||||
CONF_REQUEST_TIMEOUT: 30,
|
||||
}
|
||||
subentry = ConfigSubentry(data=sub_data, subentry_type="conversation", title="sub", unique_id=None)
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_CPP},
|
||||
options={},
|
||||
version=3,
|
||||
minor_version=0,
|
||||
)
|
||||
entry.subentries = {"sub": subentry}
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
updated_subentries = []
|
||||
update_calls = []
|
||||
|
||||
def fake_update_subentry(cfg_entry, old_sub, *, data=None, **_kwargs):
|
||||
updated_subentries.append((cfg_entry, old_sub, data))
|
||||
|
||||
def fake_update_entry(cfg_entry, **kwargs):
|
||||
update_calls.append(kwargs)
|
||||
|
||||
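    # Stub out the HuggingFace download so the v3.0 -> v3.1 migration never touches the network.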
    monkeypatch.setattr(
        "custom_components.llama_conversation.download_model_from_hf", lambda *_args, **_kw: "file.gguf"
    )
    monkeypatch.setattr(hass.config_entries, "async_update_subentry", fake_update_subentry)
    monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)

    result = await async_migrate_entry(hass, entry)

    assert result is True
    assert updated_subentries, "Subentry should be updated with downloaded file"
    new_data = updated_subentries[0][2]
    assert new_data[CONF_DOWNLOADED_MODEL_FILE] == "file.gguf"
    assert any(call.get("minor_version") == 1 for call in update_calls)


@pytest.mark.parametrize(
    "api_value,expected_list",
    [("api-1", ["api-1"]), (None, [])],
)
@pytest.mark.asyncio
async def test_migrate_v3_minor1_converts_api_to_list(monkeypatch, hass, api_value, expected_list):
    entry = MockConfigEntry(
        domain=DOMAIN,
        data={CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI},
        options={CONF_LLM_HASS_API: api_value},
        version=3,
        minor_version=1,
    )
    entry.add_to_hass(hass)

    calls = []

    def fake_update_entry(cfg_entry, **kwargs):
        calls.append(kwargs)
        if "options" in kwargs:
            cfg_entry._options = kwargs["options"]  # type: ignore[attr-defined]
        if "minor_version" in kwargs:
            cfg_entry._minor_version = kwargs["minor_version"]  # type: ignore[attr-defined]

    monkeypatch.setattr(hass.config_entries, "async_update_entry", fake_update_entry)

    result = await async_migrate_entry(hass, entry)

    assert result is True
    options_calls = [c for c in calls if "options" in c]
    assert options_calls, "async_update_entry should be called with options"
    assert options_calls[-1]["options"][CONF_LLM_HASS_API] == expected_list

    minor_calls = [c for c in calls if c.get("minor_version")]
    assert minor_calls and minor_calls[-1]["minor_version"] == 2