mirror of https://github.com/acon96/home-llm.git
synced 2026-01-07 21:04:08 -05:00

more tests, fix missing default options, and load ICL as utf8
@@ -137,7 +137,7 @@ class LLaMAAgent(AbstractConversationAgent):
        try:
            icl_filename = os.path.join(os.path.dirname(__file__), filename)

            with open(icl_filename) as f:
            with open(icl_filename, encoding="utf-8-sig") as f:
                self.in_context_examples = list(csv.DictReader(f))

            if set(self.in_context_examples[0].keys()) != set(["service", "response" ]):
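Note on the change above: "utf-8-sig" makes Python strip a leading UTF-8 byte-order mark while decoding, which spreadsheet exports often prepend to CSV files; with plain "utf-8" the BOM stays glued to the first column name and the column-name check just below it would not match. A minimal sketch of the difference, with illustrative CSV content:

import csv, io

raw = b"\xef\xbb\xbfservice,response\nlight.turn_on,Turning on the light\n"

# plain utf-8 keeps the BOM attached to the first header name
rows = list(csv.DictReader(io.StringIO(raw.decode("utf-8"))))
assert "\ufeffservice" in rows[0]

# utf-8-sig strips the BOM, so the expected column names come through
rows = list(csv.DictReader(io.StringIO(raw.decode("utf-8-sig"))))
assert set(rows[0].keys()) == {"service", "response"}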
@@ -162,16 +162,21 @@ DEFAULT_OPTIONS = types.MappingProxyType(
        CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
        CONF_USE_GBNF_GRAMMAR: DEFAULT_USE_GBNF_GRAMMAR,
        CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
        CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
        CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
        CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
        CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
        CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
        CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
        CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
        CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
        CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
        CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
        CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
        CONF_BATCH_SIZE: DEFAULT_BATCH_SIZE,
        CONF_THREAD_COUNT: DEFAULT_THREAD_COUNT,
        CONF_BATCH_THREAD_COUNT: DEFAULT_BATCH_THREAD_COUNT,
        CONF_PROMPT_CACHING_ENABLED: DEFAULT_PROMPT_CACHING_ENABLED,
    }
)
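DEFAULT_OPTIONS is wrapped in types.MappingProxyType, a read-only view of the dict: code can read defaults from it, but nothing can mutate the shared mapping at runtime, so an option left out of this table (like CONF_REMEMBER_NUM_INTERACTIONS before this commit) has no default to fall back on. A quick illustration with an illustrative key:

import types

defaults = types.MappingProxyType({"context_length": 2048})
assert defaults["context_length"] == 2048      # reads work

try:
    defaults["context_length"] = 4096          # writes raise TypeError
except TypeError as err:
    print(err)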
2  pytest.ini  Normal file
@@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto
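With asyncio_mode = auto, pytest-asyncio treats every async def test it collects as an asyncio test and supplies the event loop itself, so the @pytest.mark.asyncio marker becomes optional (the tests in this commit still carry it, which is harmless). A minimal stand-alone example under that setting:

import asyncio

async def test_sleep_returns_none():
    # no @pytest.mark.asyncio needed when asyncio_mode = auto is set in pytest.ini
    assert await asyncio.sleep(0) is None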
@@ -12,4 +12,5 @@ homeassistant
hassil
home-assistant-intents
pytest
pytest-asynciop
pytest-asyncio
pytest-homeassistant-custom-component
@@ -78,7 +78,6 @@ from custom_components.llama_conversation.const import (
    DEFAULT_OPTIONS,
)

import homeassistant.helpers.template
from homeassistant.components.conversation import ConversationInput
from homeassistant.const import (
    CONF_HOST,
@@ -102,19 +101,6 @@ class MockConfigEntry:
        self.entry_id = entry_id
        self.data = WarnDict(data)
        self.options = WarnDict(options)


# @pytest.fixture
# def patch_dependency_group1():
#     with patch('path.to.dependency1') as mock1, \
#          patch('path.to.dependency2') as mock2:
#         yield mock1, mock2

# @pytest.fixture
# def patch_dependency_group2():
#     with patch('path.to.dependency3') as mock3, \
#          patch('path.to.dependency4') as mock4:
#         yield mock3, mock4


@pytest.fixture
@@ -195,34 +181,77 @@ def local_llama_agent_fixture(config_entry, home_assistant_mock):

    yield agent_obj, all_mocks

# TODO: test base llama agent (ICL loading other languages)

@pytest.mark.asyncio # This decorator is necessary for pytest to run async test functions
async def test_local_llama_agent(local_llama_agent_fixture):

    local_llama_agent: LocalLLaMAAgent
    all_mocks: dict[str, MagicMock]
    local_llama_agent, all_mocks = local_llama_agent_fixture

    # invoke the conversation agent
    conversation_id = "test-conversation"
    result = await local_llama_agent.async_process(ConversationInput(
        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
    ))

    # assert on results + check side effects
    assert result.response.speech['plain']['speech'] == "I am saying something!"

    all_mocks["llama_class"].assert_called_once_with(
        model_path=ANY,
        n_ctx=ANY,
        n_batch=ANY,
        n_threads=ANY,
        n_threads_batch=ANY,
        model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
        n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
        n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
        n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
        n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
    )

    all_mocks["tokenize"].assert_called_once()
    all_mocks["generate"].assert_called_once_with(
        ANY,
        temp=ANY,
        top_k=ANY,
        top_p=ANY,
        temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
        top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
        top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
        grammar=ANY,
    )
    )

    # reset mock stats
    for mock in all_mocks.values():
        mock.reset_mock()

    # change options then apply them
    local_llama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
    local_llama_agent.entry.options[CONF_BATCH_SIZE] = 1024
    local_llama_agent.entry.options[CONF_THREAD_COUNT] = 24
    local_llama_agent.entry.options[CONF_BATCH_THREAD_COUNT] = 24
    local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
    local_llama_agent.entry.options[CONF_TOP_K] = 20
    local_llama_agent.entry.options[CONF_TOP_P] = 0.9

    local_llama_agent._update_options()

    all_mocks["llama_class"].assert_called_once_with(
        model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
        n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
        n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
        n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
        n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
    )

    # do another turn of the same conversation
    result = await local_llama_agent.async_process(ConversationInput(
        "turn off the office lights", MagicMock(), conversation_id, None, "en"
    ))

    assert result.response.speech['plain']['speech'] == "I am saying something!"

    all_mocks["tokenize"].assert_called_once()
    all_mocks["generate"].assert_called_once_with(
        ANY,
        temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
        top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
        top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
        grammar=ANY,
    )

# TODO: test backends: text-gen-webui, ollama, generic openai
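The assertions in this test mix exact values with unittest.mock.ANY: ANY compares equal to anything, so assert_called_once_with can pin down only the arguments the test cares about (the configured options) while ignoring incidental ones such as the prompt tokens or the grammar object. A small self-contained sketch of the pattern, with illustrative values:

from unittest.mock import MagicMock, ANY

generate = MagicMock()
generate([1, 2, 3], temp=0.7, top_k=40, top_p=0.95, grammar=object())

# pin the sampling parameters, let the prompt tokens and grammar be anything
generate.assert_called_once_with(ANY, temp=0.7, top_k=40, top_p=0.95, grammar=ANY)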
@@ -1,21 +1,342 @@
import pytest
from unittest.mock import patch
from unittest.mock import patch, MagicMock

# @pytest.fixture
# def patch_dependency_group1():
#     with patch('path.to.dependency1') as mock1, \
#          patch('path.to.dependency2') as mock2:
#         yield mock1, mock2
from homeassistant import config_entries, setup
from homeassistant.core import HomeAssistant
from homeassistant.const import (
    CONF_HOST,
    CONF_PORT,
    CONF_SSL,
)
from homeassistant.data_entry_flow import FlowResultType

# @pytest.fixture
# def patch_dependency_group2():
#     with patch('path.to.dependency3') as mock3, \
#          patch('path.to.dependency4') as mock4:
#         yield mock3, mock4
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema, ConfigFlow
from custom_components.llama_conversation.const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
    CONF_PROMPT,
    CONF_TEMPERATURE,
    CONF_TOP_K,
    CONF_TOP_P,
    CONF_REQUEST_TIMEOUT,
    CONF_BACKEND_TYPE,
    CONF_DOWNLOADED_MODEL_FILE,
    CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
    CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
    CONF_PROMPT_TEMPLATE,
    CONF_USE_GBNF_GRAMMAR,
    CONF_GBNF_GRAMMAR_FILE,
    CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
    CONF_IN_CONTEXT_EXAMPLES_FILE,
    CONF_NUM_IN_CONTEXT_EXAMPLES,
    CONF_TEXT_GEN_WEBUI_PRESET,
    CONF_OPENAI_API_KEY,
    CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
    CONF_REFRESH_SYSTEM_PROMPT,
    CONF_REMEMBER_CONVERSATION,
    CONF_REMEMBER_NUM_INTERACTIONS,
    CONF_PROMPT_CACHING_ENABLED,
    CONF_PROMPT_CACHING_INTERVAL,
    CONF_SERVICE_CALL_REGEX,
    CONF_REMOTE_USE_CHAT_ENDPOINT,
    CONF_TEXT_GEN_WEBUI_CHAT_MODE,
    CONF_OLLAMA_KEEP_ALIVE_MIN,
    CONF_OLLAMA_JSON_MODE,
    CONF_CONTEXT_LENGTH,
    CONF_BATCH_SIZE,
    CONF_THREAD_COUNT,
    CONF_BATCH_THREAD_COUNT,
    BACKEND_TYPE_LLAMA_HF,
    BACKEND_TYPE_LLAMA_EXISTING,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
    BACKEND_TYPE_GENERIC_OPENAI,
    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
    BACKEND_TYPE_OLLAMA,
    DEFAULT_CHAT_MODEL,
    DEFAULT_PROMPT,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_BACKEND_TYPE,
    DEFAULT_REQUEST_TIMEOUT,
    DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
    DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
    DEFAULT_PROMPT_TEMPLATE,
    DEFAULT_USE_GBNF_GRAMMAR,
    DEFAULT_GBNF_GRAMMAR_FILE,
    DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
    DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
    DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
    DEFAULT_REFRESH_SYSTEM_PROMPT,
    DEFAULT_REMEMBER_CONVERSATION,
    DEFAULT_REMEMBER_NUM_INTERACTIONS,
    DEFAULT_PROMPT_CACHING_ENABLED,
    DEFAULT_PROMPT_CACHING_INTERVAL,
    DEFAULT_SERVICE_CALL_REGEX,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
    DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
    DEFAULT_OLLAMA_JSON_MODE,
    DEFAULT_CONTEXT_LENGTH,
    DEFAULT_BATCH_SIZE,
    DEFAULT_THREAD_COUNT,
    DEFAULT_BATCH_THREAD_COUNT,
    DOMAIN,
)

# async def test_validate_config_flow_llama_hf(hass: HomeAssistant):
#     result = await hass.config_entries.flow.async_init(
#         DOMAIN, context={"source": config_entries.SOURCE_USER}
#     )
#     assert result["type"] == FlowResultType.FORM
#     assert result["errors"] is None

def test_validate_setup_schemas():
    pass
#     result2 = await hass.config_entries.flow.async_configure(
#         result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_HF },
#     )
#     assert result2["type"] == FlowResultType.FORM

#     with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry:
#         result3 = await hass.config_entries.flow.async_configure(
#             result2["flow_id"],
#             TEST_DATA,
#         )
#         await hass.async_block_till_done()

#     assert result3["type"] == "create_entry"
#     assert result3["title"] == ""
#     assert result3["data"] == {
#         # ACCOUNT_ID: TEST_DATA["account_id"],
#         # CONF_PASSWORD: TEST_DATA["password"],
#         # CONNECTION_TYPE: CLOUD,
#     }
#     assert result3["options"] == {}
#     assert len(mock_setup_entry.mock_calls) == 1

@pytest.fixture
def validate_connections_mock():
    validate_mock = MagicMock()
    with patch.object(ConfigFlow, '_validate_text_generation_webui', new=validate_mock), \
         patch.object(ConfigFlow, '_validate_ollama', new=validate_mock):
        yield validate_mock

@pytest.fixture
def mock_setup_entry():
    with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry, \
         patch("custom_components.llama_conversation.async_unload_entry", return_value=True):
        yield mock_setup_entry
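validate_connections_mock patches both backend validators with the same MagicMock, so a test can queue results via side_effect and have the first connection attempt fail while the retry succeeds, which is how the ollama flow test below simulates a bad host followed by a corrected one. A reduced sketch of that mechanism, with illustrative arguments:

from unittest.mock import MagicMock

validate = MagicMock()
validate.side_effect = [
    ("failed_to_connect", Exception("ConnectionError")),  # first attempt
    (None, None),                                         # retry succeeds
]

assert validate("http://bad-host")[0] == "failed_to_connect"
assert validate("http://good-host") == (None, None)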

async def test_validate_config_flow_generic_openai(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations):
    result = await hass.config_entries.flow.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert result["type"] == FlowResultType.FORM
    assert result["errors"] == {}
    assert result["step_id"] == "pick_backend"

    result2 = await hass.config_entries.flow.async_configure(
        result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI },
    )

    assert result2["type"] == FlowResultType.FORM
    assert result2["errors"] == {}
    assert result2["step_id"] == "remote_model"

    result3 = await hass.config_entries.flow.async_configure(
        result2["flow_id"],
        {
            CONF_HOST: "localhost",
            CONF_PORT: "5000",
            CONF_SSL: False,
            CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
        },
    )

    assert result3["type"] == FlowResultType.FORM
    assert result3["errors"] == {}
    assert result3["step_id"] == "model_parameters"

    options_dict = {
        CONF_PROMPT: DEFAULT_PROMPT,
        CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
        CONF_TOP_P: DEFAULT_TOP_P,
        CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
        CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
        CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
        CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
        CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
        CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
        CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
        CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
        CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
        CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
        CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
        CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
        CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
    }

    result4 = await hass.config_entries.flow.async_configure(
        result2["flow_id"], options_dict
    )
    await hass.async_block_till_done()

    assert result4["type"] == "create_entry"
    assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
    assert result4["data"] == {
        CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI,
        CONF_HOST: "localhost",
        CONF_PORT: "5000",
        CONF_SSL: False,
        CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
    }
    assert result4["options"] == options_dict
    assert len(mock_setup_entry.mock_calls) == 1

async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations, validate_connections_mock):
    result = await hass.config_entries.flow.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert result["type"] == FlowResultType.FORM
    assert result["errors"] == {}
    assert result["step_id"] == "pick_backend"

    result2 = await hass.config_entries.flow.async_configure(
        result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA },
    )

    assert result2["type"] == FlowResultType.FORM
    assert result2["errors"] == {}
    assert result2["step_id"] == "remote_model"

    # simulate incorrect settings on first try
    validate_connections_mock.side_effect = [
        ("failed_to_connect", Exception("ConnectionError")),
        (None, None)
    ]

    result3 = await hass.config_entries.flow.async_configure(
        result2["flow_id"],
        {
            CONF_HOST: "localhost",
            CONF_PORT: "5000",
            CONF_SSL: False,
            CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
        },
    )

    assert result3["type"] == FlowResultType.FORM
    assert len(result3["errors"]) == 1
    assert "base" in result3["errors"]
    assert result3["step_id"] == "remote_model"

    # retry
    result3 = await hass.config_entries.flow.async_configure(
        result2["flow_id"],
        {
            CONF_HOST: "localhost",
            CONF_PORT: "5001",
            CONF_SSL: False,
            CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
        },
    )

    assert result3["type"] == FlowResultType.FORM
    assert result3["errors"] == {}
    assert result3["step_id"] == "model_parameters"

    options_dict = {
        CONF_PROMPT: DEFAULT_PROMPT,
        CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
        CONF_TOP_P: DEFAULT_TOP_P,
        CONF_TOP_K: DEFAULT_TOP_K,
        CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
        CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
        CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
        CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
        CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
        CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
        CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
        CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
        CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
        CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
        CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
        CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
        CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
        CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
        CONF_OLLAMA_KEEP_ALIVE_MIN: DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
        CONF_OLLAMA_JSON_MODE: DEFAULT_OLLAMA_JSON_MODE,
    }

    result4 = await hass.config_entries.flow.async_configure(
        result2["flow_id"], options_dict
    )
    await hass.async_block_till_done()

    assert result4["type"] == "create_entry"
    assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
    assert result4["data"] == {
        CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA,
        CONF_HOST: "localhost",
        CONF_PORT: "5001",
        CONF_SSL: False,
        CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
    }
    assert result4["options"] == options_dict
    mock_setup_entry.assert_called_once()

def test_validate_options_schema():
    pass

    universal_options = [
        CONF_PROMPT, CONF_PROMPT_TEMPLATE,
        CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES,
        CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
        CONF_SERVICE_CALL_REGEX, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS,
    ]

    options_llama_hf = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_HF)
    assert set(options_llama_hf.keys()) == set(universal_options + [
        CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
        CONF_CONTEXT_LENGTH, # supports context length
        CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
        CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
    ])

    options_llama_existing = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_EXISTING)
    assert set(options_llama_existing.keys()) == set(universal_options + [
        CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
        CONF_CONTEXT_LENGTH, # supports context length
        CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
        CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
    ])

    options_ollama = local_llama_config_option_schema(None, BACKEND_TYPE_OLLAMA)
    assert set(options_ollama.keys()) == set(universal_options + [
        CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
        CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE, # ollama specific
        CONF_CONTEXT_LENGTH, # supports context length
        CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
    ])

    options_text_gen_webui = local_llama_config_option_schema(None, BACKEND_TYPE_TEXT_GEN_WEBUI)
    assert set(options_text_gen_webui.keys()) == set(universal_options + [
        CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
        CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, # text-gen-webui specific
        CONF_CONTEXT_LENGTH, # supports context length
        CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
    ])

    options_generic_openai = local_llama_config_option_schema(None, BACKEND_TYPE_GENERIC_OPENAI)
    assert set(options_generic_openai.keys()) == set(universal_options + [
        CONF_TEMPERATURE, CONF_TOP_P, # only supports top_p and temperature sampling
        CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
    ])

    options_llama_cpp_python_server = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER)
    assert set(options_llama_cpp_python_server.keys()) == set(universal_options + [
        CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
        CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
        CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
    ])