mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-09 13:48:05 -05:00
349 lines
14 KiB
Python
349 lines
14 KiB
Python
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from homeassistant import config_entries, setup
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.const import (
|
|
CONF_HOST,
|
|
CONF_PORT,
|
|
CONF_SSL,
|
|
)
|
|
from homeassistant.data_entry_flow import FlowResultType
|
|
|
|
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema, ConfigFlow
|
|
from custom_components.llama_conversation.const import (
|
|
CONF_CHAT_MODEL,
|
|
CONF_MAX_TOKENS,
|
|
CONF_PROMPT,
|
|
CONF_TEMPERATURE,
|
|
CONF_TOP_K,
|
|
CONF_TOP_P,
|
|
CONF_MIN_P,
|
|
CONF_TYPICAL_P,
|
|
CONF_REQUEST_TIMEOUT,
|
|
CONF_BACKEND_TYPE,
|
|
CONF_DOWNLOADED_MODEL_FILE,
|
|
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
|
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
|
|
CONF_PROMPT_TEMPLATE,
|
|
CONF_USE_GBNF_GRAMMAR,
|
|
CONF_GBNF_GRAMMAR_FILE,
|
|
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
|
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
|
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
|
CONF_TEXT_GEN_WEBUI_PRESET,
|
|
CONF_OPENAI_API_KEY,
|
|
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
|
|
CONF_REFRESH_SYSTEM_PROMPT,
|
|
CONF_REMEMBER_CONVERSATION,
|
|
CONF_REMEMBER_NUM_INTERACTIONS,
|
|
CONF_PROMPT_CACHING_ENABLED,
|
|
CONF_PROMPT_CACHING_INTERVAL,
|
|
CONF_SERVICE_CALL_REGEX,
|
|
CONF_REMOTE_USE_CHAT_ENDPOINT,
|
|
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
|
CONF_OLLAMA_KEEP_ALIVE_MIN,
|
|
CONF_OLLAMA_JSON_MODE,
|
|
CONF_CONTEXT_LENGTH,
|
|
CONF_BATCH_SIZE,
|
|
CONF_THREAD_COUNT,
|
|
CONF_BATCH_THREAD_COUNT,
|
|
BACKEND_TYPE_LLAMA_HF,
|
|
BACKEND_TYPE_LLAMA_EXISTING,
|
|
BACKEND_TYPE_TEXT_GEN_WEBUI,
|
|
BACKEND_TYPE_GENERIC_OPENAI,
|
|
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
|
|
BACKEND_TYPE_OLLAMA,
|
|
DEFAULT_CHAT_MODEL,
|
|
DEFAULT_PROMPT,
|
|
DEFAULT_MAX_TOKENS,
|
|
DEFAULT_TEMPERATURE,
|
|
DEFAULT_TOP_K,
|
|
DEFAULT_TOP_P,
|
|
DEFAULT_MIN_P,
|
|
DEFAULT_TYPICAL_P,
|
|
DEFAULT_BACKEND_TYPE,
|
|
DEFAULT_REQUEST_TIMEOUT,
|
|
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
|
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
|
|
DEFAULT_PROMPT_TEMPLATE,
|
|
DEFAULT_USE_GBNF_GRAMMAR,
|
|
DEFAULT_GBNF_GRAMMAR_FILE,
|
|
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
|
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
|
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
|
DEFAULT_REFRESH_SYSTEM_PROMPT,
|
|
DEFAULT_REMEMBER_CONVERSATION,
|
|
DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
|
DEFAULT_PROMPT_CACHING_ENABLED,
|
|
DEFAULT_PROMPT_CACHING_INTERVAL,
|
|
DEFAULT_SERVICE_CALL_REGEX,
|
|
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
|
|
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
|
|
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
|
|
DEFAULT_OLLAMA_JSON_MODE,
|
|
DEFAULT_CONTEXT_LENGTH,
|
|
DEFAULT_BATCH_SIZE,
|
|
DEFAULT_THREAD_COUNT,
|
|
DEFAULT_BATCH_THREAD_COUNT,
|
|
DOMAIN,
|
|
)
|
|
|
|
# async def test_validate_config_flow_llama_hf(hass: HomeAssistant):
|
|
# result = await hass.config_entries.flow.async_init(
|
|
# DOMAIN, context={"source": config_entries.SOURCE_USER}
|
|
# )
|
|
# assert result["type"] == FlowResultType.FORM
|
|
# assert result["errors"] is None
|
|
|
|
# result2 = await hass.config_entries.flow.async_configure(
|
|
# result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_HF },
|
|
# )
|
|
# assert result2["type"] == FlowResultType.FORM
|
|
|
|
# with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry:
|
|
# result3 = await hass.config_entries.flow.async_configure(
|
|
# result2["flow_id"],
|
|
# TEST_DATA,
|
|
# )
|
|
# await hass.async_block_till_done()
|
|
|
|
# assert result3["type"] == "create_entry"
|
|
# assert result3["title"] == ""
|
|
# assert result3["data"] == {
|
|
# # ACCOUNT_ID: TEST_DATA["account_id"],
|
|
# # CONF_PASSWORD: TEST_DATA["password"],
|
|
# # CONNECTION_TYPE: CLOUD,
|
|
# }
|
|
# assert result3["options"] == {}
|
|
# assert len(mock_setup_entry.mock_calls) == 1
|
|
|
|
@pytest.fixture
|
|
def validate_connections_mock():
|
|
validate_mock = MagicMock()
|
|
with patch.object(ConfigFlow, '_validate_text_generation_webui', new=validate_mock), \
|
|
patch.object(ConfigFlow, '_validate_ollama', new=validate_mock):
|
|
yield validate_mock
|
|
|
|
@pytest.fixture
|
|
def mock_setup_entry():
|
|
with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry, \
|
|
patch("custom_components.llama_conversation.async_unload_entry", return_value=True):
|
|
yield mock_setup_entry
|
|
|
|
async def test_validate_config_flow_generic_openai(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations):
|
|
result = await hass.config_entries.flow.async_init(
|
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
|
)
|
|
assert result["type"] == FlowResultType.FORM
|
|
assert result["errors"] == {}
|
|
assert result["step_id"] == "pick_backend"
|
|
|
|
result2 = await hass.config_entries.flow.async_configure(
|
|
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI },
|
|
)
|
|
|
|
assert result2["type"] == FlowResultType.FORM
|
|
assert result2["errors"] == {}
|
|
assert result2["step_id"] == "remote_model"
|
|
|
|
result3 = await hass.config_entries.flow.async_configure(
|
|
result2["flow_id"],
|
|
{
|
|
CONF_HOST: "localhost",
|
|
CONF_PORT: "5000",
|
|
CONF_SSL: False,
|
|
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
|
},
|
|
)
|
|
|
|
assert result3["type"] == FlowResultType.FORM
|
|
assert result3["errors"] == {}
|
|
assert result3["step_id"] == "model_parameters"
|
|
|
|
options_dict = {
|
|
CONF_PROMPT: DEFAULT_PROMPT,
|
|
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
|
|
CONF_TOP_P: DEFAULT_TOP_P,
|
|
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
|
|
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
|
|
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
|
|
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
|
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
|
|
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
|
|
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
|
|
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
|
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
|
|
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
|
|
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
|
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
|
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
|
}
|
|
|
|
result4 = await hass.config_entries.flow.async_configure(
|
|
result2["flow_id"], options_dict
|
|
)
|
|
await hass.async_block_till_done()
|
|
|
|
assert result4["type"] == "create_entry"
|
|
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
|
|
assert result4["data"] == {
|
|
CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI,
|
|
CONF_HOST: "localhost",
|
|
CONF_PORT: "5000",
|
|
CONF_SSL: False,
|
|
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
|
}
|
|
assert result4["options"] == options_dict
|
|
assert len(mock_setup_entry.mock_calls) == 1
|
|
|
|
async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations, validate_connections_mock):
|
|
result = await hass.config_entries.flow.async_init(
|
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
|
)
|
|
assert result["type"] == FlowResultType.FORM
|
|
assert result["errors"] == {}
|
|
assert result["step_id"] == "pick_backend"
|
|
|
|
result2 = await hass.config_entries.flow.async_configure(
|
|
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA },
|
|
)
|
|
|
|
assert result2["type"] == FlowResultType.FORM
|
|
assert result2["errors"] == {}
|
|
assert result2["step_id"] == "remote_model"
|
|
|
|
# simulate incorrect settings on first try
|
|
validate_connections_mock.side_effect = [
|
|
("failed_to_connect", Exception("ConnectionError"), []),
|
|
(None, None, [])
|
|
]
|
|
|
|
result3 = await hass.config_entries.flow.async_configure(
|
|
result2["flow_id"],
|
|
{
|
|
CONF_HOST: "localhost",
|
|
CONF_PORT: "5000",
|
|
CONF_SSL: False,
|
|
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
|
},
|
|
)
|
|
|
|
assert result3["type"] == FlowResultType.FORM
|
|
assert len(result3["errors"]) == 1
|
|
assert "base" in result3["errors"]
|
|
assert result3["step_id"] == "remote_model"
|
|
|
|
# retry
|
|
result3 = await hass.config_entries.flow.async_configure(
|
|
result2["flow_id"],
|
|
{
|
|
CONF_HOST: "localhost",
|
|
CONF_PORT: "5001",
|
|
CONF_SSL: False,
|
|
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
|
},
|
|
)
|
|
|
|
assert result3["type"] == FlowResultType.FORM
|
|
assert result3["errors"] == {}
|
|
assert result3["step_id"] == "model_parameters"
|
|
|
|
options_dict = {
|
|
CONF_PROMPT: DEFAULT_PROMPT,
|
|
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
|
|
CONF_TOP_P: DEFAULT_TOP_P,
|
|
CONF_TOP_K: DEFAULT_TOP_K,
|
|
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
|
|
CONF_TYPICAL_P: DEFAULT_MIN_P,
|
|
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
|
|
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
|
|
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
|
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
|
|
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
|
|
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
|
|
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
|
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
|
|
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
|
|
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
|
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
|
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
|
CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
|
|
CONF_OLLAMA_KEEP_ALIVE_MIN: DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
|
|
CONF_OLLAMA_JSON_MODE: DEFAULT_OLLAMA_JSON_MODE,
|
|
}
|
|
|
|
result4 = await hass.config_entries.flow.async_configure(
|
|
result2["flow_id"], options_dict
|
|
)
|
|
await hass.async_block_till_done()
|
|
|
|
assert result4["type"] == "create_entry"
|
|
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
|
|
assert result4["data"] == {
|
|
CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA,
|
|
CONF_HOST: "localhost",
|
|
CONF_PORT: "5001",
|
|
CONF_SSL: False,
|
|
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
|
}
|
|
assert result4["options"] == options_dict
|
|
mock_setup_entry.assert_called_once()
|
|
|
|
# TODO: write tests for configflow setup for llama.cpp (both versions) + text-generation-webui
|
|
|
|
def test_validate_options_schema():
|
|
|
|
universal_options = [
|
|
CONF_PROMPT, CONF_PROMPT_TEMPLATE,
|
|
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES,
|
|
CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
|
|
CONF_SERVICE_CALL_REGEX, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS,
|
|
]
|
|
|
|
options_llama_hf = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_HF)
|
|
assert set(options_llama_hf.keys()) == set(universal_options + [
|
|
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
|
|
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
|
|
CONF_CONTEXT_LENGTH, # supports context length
|
|
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
|
|
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
|
|
])
|
|
|
|
options_llama_existing = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_EXISTING)
|
|
assert set(options_llama_existing.keys()) == set(universal_options + [
|
|
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
|
|
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
|
|
CONF_CONTEXT_LENGTH, # supports context length
|
|
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
|
|
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
|
|
])
|
|
|
|
options_ollama = local_llama_config_option_schema(None, BACKEND_TYPE_OLLAMA)
|
|
assert set(options_ollama.keys()) == set(universal_options + [
|
|
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_TYPICAL_P, # supports top_k temperature, top_p and typical_p samplers
|
|
CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE, # ollama specific
|
|
CONF_CONTEXT_LENGTH, # supports context length
|
|
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
|
])
|
|
|
|
options_text_gen_webui = local_llama_config_option_schema(None, BACKEND_TYPE_TEXT_GEN_WEBUI)
|
|
assert set(options_text_gen_webui.keys()) == set(universal_options + [
|
|
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
|
|
CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, # text-gen-webui specific
|
|
CONF_CONTEXT_LENGTH, # supports context length
|
|
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
|
])
|
|
|
|
options_generic_openai = local_llama_config_option_schema(None, BACKEND_TYPE_GENERIC_OPENAI)
|
|
assert set(options_generic_openai.keys()) == set(universal_options + [
|
|
CONF_TEMPERATURE, CONF_TOP_P, # only supports top_p and temperature sampling
|
|
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
|
])
|
|
|
|
options_llama_cpp_python_server = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER)
|
|
assert set(options_llama_cpp_python_server.keys()) == set(universal_options + [
|
|
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports top_k, temperature, and top p sampling
|
|
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
|
|
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
|
|
]) |