fix some tests

This commit is contained in:
Alex O'Connell
2024-06-08 16:08:50 -04:00
parent 127ea425f7
commit e527366506
3 changed files with 70 additions and 55 deletions

View File

@@ -394,10 +394,12 @@ class LocalLLMAgent(AbstractConversationAgent):
try:
tool_response = await llm_api.async_call_tool(tool_input)
_LOGGER.debug("Tool response: %s", tool_response)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)
_LOGGER.debug("Tool response: %s", tool_response)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
@@ -408,8 +410,6 @@ class LocalLLMAgent(AbstractConversationAgent):
response=intent_response, conversation_id=conversation_id
)
_LOGGER.debug("Tool response: %s", tool_response)
# handle models that generate a function call and wait for the result before providing a response
if self.entry.options.get(CONF_TOOL_MULTI_TURN_CHAT, DEFAULT_TOOL_MULTI_TURN_CHAT):
conversation.append({"role": "tool", "message": json.dumps(tool_response)})
@@ -436,7 +436,7 @@ class LocalLLMAgent(AbstractConversationAgent):
# generate intent response to Home Assistant
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(to_say)
intent_response.async_set_speech(to_say.strip())
return ConversationResult(
response=intent_response, conversation_id=conversation_id
)

View File

@@ -86,8 +86,10 @@ from homeassistant.components.conversation import ConversationInput
from homeassistant.const import (
CONF_HOST,
CONF_PORT,
CONF_SSL
CONF_SSL,
CONF_LLM_HASS_API
)
from homeassistant.helpers.llm import LLM_API_ASSIST, APIInstance
_LOGGER = logging.getLogger(__name__)
@@ -122,26 +124,19 @@ def config_entry():
},
options={
**DEFAULT_OPTIONS,
CONF_LLM_HASS_API: LLM_API_ASSIST,
CONF_PROMPT: DEFAULT_PROMPT_BASE,
CONF_SERVICE_CALL_REGEX: r"({[\S \t]*})"
}
)
@pytest.fixture
def home_assistant_mock():
mock_home_assistant = MagicMock()
async def call_now(func, *args, **kwargs):
return func(*args, **kwargs)
mock_home_assistant.async_add_executor_job.side_effect = call_now
mock_home_assistant.services.async_call = AsyncMock()
yield mock_home_assistant
@pytest.fixture
def local_llama_agent_fixture(config_entry, home_assistant_mock):
def local_llama_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(LlamaCppAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(LlamaCppAgent, '_load_grammar') as load_grammar_mock, \
patch.object(LlamaCppAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.importlib.import_module') as import_module_mock, \
patch('custom_components.llama_conversation.agent.install_llama_cpp_python') as install_llama_cpp_python_mock:
@@ -166,14 +161,17 @@ def local_llama_agent_fixture(config_entry, home_assistant_mock):
generate_mock.return_value = list(range(20))
detokenize_mock = llama_instance_mock.detokenize
detokenize_mock.return_value = json.dumps({
"to_say": "I am saying something!",
"service": "light.turn_on",
"target_device": "light.kitchen_light",
}).encode()
detokenize_mock.return_value = ("I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
})).encode()
call_tool_mock.return_value = {"result": "success"}
agent_obj = LlamaCppAgent(
home_assistant_mock,
hass,
config_entry
)
@@ -268,10 +266,11 @@ async def test_local_llama_agent(local_llama_agent_fixture):
)
@pytest.fixture
def ollama_agent_fixture(config_entry, home_assistant_mock):
def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(OllamaAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(OllamaAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(OllamaAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
@@ -291,8 +290,10 @@ def ollama_agent_fixture(config_entry, home_assistant_mock):
response_mock.json.return_value = { "models": [ {"name": config_entry.data[CONF_CHAT_MODEL] }] }
requests_get_mock.return_value = response_mock
call_tool_mock.return_value = {"result": "success"}
agent_obj = OllamaAPIAgent(
home_assistant_mock,
hass,
config_entry
)
@@ -318,10 +319,11 @@ async def test_ollama_agent(ollama_agent_fixture):
response_mock.json.return_value = {
"model": ollama_agent.entry.data[CONF_CHAT_MODEL],
"created_at": "2023-11-09T21:07:55.186497Z",
"response": json.dumps({
"to_say": "I am saying something!",
"service": "light.turn_on",
"target_device": "light.kitchen_light",
"response": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"done": True,
"context": [1, 2, 3],
@@ -410,10 +412,11 @@ async def test_ollama_agent(ollama_agent_fixture):
@pytest.fixture
def text_generation_webui_agent_fixture(config_entry, home_assistant_mock):
def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(TextGenerationWebuiAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(TextGenerationWebuiAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(TextGenerationWebuiAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
@@ -433,8 +436,10 @@ def text_generation_webui_agent_fixture(config_entry, home_assistant_mock):
response_mock.json.return_value = { "model_name": config_entry.data[CONF_CHAT_MODEL] }
requests_get_mock.return_value = response_mock
call_tool_mock.return_value = {"result": "success"}
agent_obj = TextGenerationWebuiAgent(
home_assistant_mock,
hass,
config_entry
)
@@ -464,10 +469,11 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
"model": "gpt-3.5-turbo-instruct",
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"text": json.dumps({
"to_say": "I am saying something!",
"service": "light.turn_on",
"target_device": "light.kitchen_light",
"text": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"index": 0,
"logprobs": None,
@@ -559,10 +565,11 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
"index": 0,
"message": {
"role": "assistant",
"content": json.dumps({
"to_say": "I am saying something!",
"service": "light.turn_on",
"target_device": "light.kitchen_light",
"content": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
},
"logprobs": None,
@@ -669,10 +676,11 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
)
@pytest.fixture
def generic_openai_agent_fixture(config_entry, home_assistant_mock):
def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations):
with patch.object(GenericOpenAIAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(GenericOpenAIAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(GenericOpenAIAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
@@ -688,8 +696,10 @@ def generic_openai_agent_fixture(config_entry, home_assistant_mock):
["light", "switch", "fan"]
)
call_tool_mock.return_value = {"result": "success"}
agent_obj = GenericOpenAIAPIAgent(
home_assistant_mock,
hass,
config_entry
)
@@ -714,10 +724,11 @@ async def test_generic_openai_agent(generic_openai_agent_fixture):
"model": "gpt-3.5-turbo-instruct",
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"text": json.dumps({
"to_say": "I am saying something!",
"service": "light.turn_on",
"target_device": "light.kitchen_light",
"text": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
"index": 0,
"logprobs": None,
@@ -774,10 +785,11 @@ async def test_generic_openai_agent(generic_openai_agent_fixture):
"index": 0,
"message": {
"role": "assistant",
"content": json.dumps({
"to_say": "I am saying something!",
"service": "light.turn_on",
"target_device": "light.kitchen_light",
"content": "I am saying something!\n" + json.dumps({
"name": "HassTurnOn",
"arguments": {
"name": "light.kitchen_light"
}
}),
},
"logprobs": None,

View File

@@ -7,6 +7,7 @@ from homeassistant.const import (
CONF_HOST,
CONF_PORT,
CONF_SSL,
CONF_LLM_HASS_API,
)
from homeassistant.data_entry_flow import FlowResultType
@@ -25,6 +26,8 @@ from custom_components.llama_conversation.const import (
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_PROMPT_TEMPLATE,
CONF_TOOL_FORMAT,
CONF_TOOL_MULTI_TURN_CHAT,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
@@ -290,16 +293,16 @@ async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant
# TODO: write tests for configflow setup for llama.cpp (both versions) + text-generation-webui
def test_validate_options_schema():
def test_validate_options_schema(hass: HomeAssistant):
universal_options = [
CONF_PROMPT, CONF_PROMPT_TEMPLATE,
CONF_LLM_HASS_API, CONF_PROMPT, CONF_PROMPT_TEMPLATE, CONF_TOOL_FORMAT, CONF_TOOL_MULTI_TURN_CHAT,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_SERVICE_CALL_REGEX, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS,
]
options_llama_hf = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_HF)
options_llama_hf = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_HF)
assert set(options_llama_hf.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
@@ -308,7 +311,7 @@ def test_validate_options_schema():
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
])
options_llama_existing = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_EXISTING)
options_llama_existing = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_EXISTING)
assert set(options_llama_existing.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
@@ -317,7 +320,7 @@ def test_validate_options_schema():
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
])
options_ollama = local_llama_config_option_schema(None, BACKEND_TYPE_OLLAMA)
options_ollama = local_llama_config_option_schema(hass, None, BACKEND_TYPE_OLLAMA)
assert set(options_ollama.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_TYPICAL_P, # supports top_k temperature, top_p and typical_p samplers
CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE, # ollama specific
@@ -325,7 +328,7 @@ def test_validate_options_schema():
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_text_gen_webui = local_llama_config_option_schema(None, BACKEND_TYPE_TEXT_GEN_WEBUI)
options_text_gen_webui = local_llama_config_option_schema(hass, None, BACKEND_TYPE_TEXT_GEN_WEBUI)
assert set(options_text_gen_webui.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, # text-gen-webui specific
@@ -333,13 +336,13 @@ def test_validate_options_schema():
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_generic_openai = local_llama_config_option_schema(None, BACKEND_TYPE_GENERIC_OPENAI)
options_generic_openai = local_llama_config_option_schema(hass, None, BACKEND_TYPE_GENERIC_OPENAI)
assert set(options_generic_openai.keys()) == set(universal_options + [
CONF_TEMPERATURE, CONF_TOP_P, # only supports top_p and temperature sampling
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_llama_cpp_python_server = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER)
options_llama_cpp_python_server = local_llama_config_option_schema(hass, None, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER)
assert set(options_llama_cpp_python_server.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports top_k, temperature, and top p sampling
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF