Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-09 21:58:00 -05:00)
Split backends into separate files and start implementing streaming + tool support
@@ -4,7 +4,10 @@ import pytest
 import jinja2
 from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY
 
-from custom_components.llama_conversation.conversation import LlamaCppAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
+from custom_components.llama_conversation.backends.llamacpp import LlamaCppAgent
+from custom_components.llama_conversation.backends.ollama import OllamaAPIAgent
+from custom_components.llama_conversation.backends.tailored_openai import TextGenerationWebuiAgent
+from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent
 from custom_components.llama_conversation.const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
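The one-to-four import split is the heart of the commit: each backend now lives in its own module. A plausible layout, inferred purely from these import paths; the per-file descriptions are assumptions, since the modules themselves are not shown in this diff:

# Inferred layout (descriptions are guesses from the class names):
# custom_components/llama_conversation/
#     backends/
#         llamacpp.py          # LlamaCppAgent: in-process llama-cpp-python
#         ollama.py            # OllamaAPIAgent: Ollama HTTP API
#         tailored_openai.py   # TextGenerationWebuiAgent: text-generation-webui's OpenAI-style API
#         generic_openai.py    # GenericOpenAIAPIAgent: any OpenAI-compatible server
#     utils.py                 # shared helpers, e.g. install_llama_cpp_python
#     const.py                 # the CONF_* keys the tests import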
@@ -106,7 +109,8 @@ class MockConfigEntry:
     def __init__(self, entry_id='test_entry_id', data={}, options={}):
         self.entry_id = entry_id
         self.data = WarnDict(data)
-        self.options = WarnDict(options)
+        # Use a mutable dict for options in tests
+        self.options = WarnDict(dict(options))
 
 
 @pytest.fixture
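The extra `dict(...)` matters because `options={}` is a mutable default argument, created once at definition time, and tests in this file mutate `entry.options` in place (the preset/instruct turns further down do exactly that). A minimal, self-contained sketch of the leak the copy prevents; `WarnDict` here is a bare stand-in for the test helper, whose real copying behaviour I'm not assuming:

class WarnDict(dict):
    pass  # stand-in: the real helper warns on unexpected key access

class Entry:
    def __init__(self, options={}):      # one shared default dict
        self.options = options           # holds a reference: risky

a, b = Entry(), Entry()
a.options["preset"] = "Some Character"
assert b.options == {"preset": "Some Character"}   # leaked across instances!

class SafeEntry:
    def __init__(self, options={}):
        self.options = WarnDict(dict(options))      # fresh copy per instance

c, d = SafeEntry(), SafeEntry()
c.options["preset"] = "Some Character"
assert d.options == {}                              # isolated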
@@ -138,14 +142,16 @@ def local_llama_agent_fixture(config_entry, hass, enable_custom_integrations):
         patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
         patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
         patch('homeassistant.helpers.template.Template') as template_mock, \
-        patch('custom_components.llama_conversation.agent.importlib.import_module') as import_module_mock, \
-        patch('custom_components.llama_conversation.agent.install_llama_cpp_python') as install_llama_cpp_python_mock:
+        patch('custom_components.llama_conversation.backends.llamacpp.importlib.import_module') as import_module_mock, \
+        patch('custom_components.llama_conversation.utils.importlib.import_module') as import_module_mock_2, \
+        patch('custom_components.llama_conversation.utils.install_llama_cpp_python') as install_llama_cpp_python_mock:
 
         entry_mock.return_value = config_entry
         llama_instance_mock = MagicMock()
         llama_class_mock = MagicMock()
         llama_class_mock.return_value = llama_instance_mock
         import_module_mock.return_value = MagicMock(Llama=llama_class_mock)
+        import_module_mock_2.return_value = MagicMock(Llama=llama_class_mock)
         install_llama_cpp_python_mock.return_value = True
         get_exposed_entities_mock.return_value = (
             {
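Patching `importlib.import_module` is how this fixture fakes the `llama_cpp` package without the native wheel installed; the patch path must name the module that performs the lookup, which is why the backend split turns one patch into two. A compressed sketch of the trick (patching the global name here for brevity, where the fixture targets the per-module paths shown above; `model_path` is illustrative):

import importlib
from unittest.mock import MagicMock, patch

# Fake module exposing a Llama class, stood up the same way the fixture does.
llama_class_mock = MagicMock()
fake_llama_cpp = MagicMock(Llama=llama_class_mock)

with patch('importlib.import_module', return_value=fake_llama_cpp):
    llama_cpp = importlib.import_module("llama_cpp")          # returns the fake
    model = llama_cpp.Llama(model_path="/models/test.gguf")   # records the call
    llama_class_mock.assert_called_once_with(model_path="/models/test.gguf")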
@@ -194,7 +200,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
     # invoke the conversation agent
     conversation_id = "test-conversation"
     result = await local_llama_agent.async_process(ConversationInput(
-        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
+        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
     ))
 
     # assert on results + check side effects
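Every `async_process` call in this file gains `agent_id="test-agent"`, tracking a newer Home Assistant `ConversationInput` signature. The field names below are my reading of the positional arguments and should be checked against the HA release being targeted:

from unittest.mock import MagicMock

from homeassistant.components.conversation import ConversationInput

ConversationInput(
    "turn on the kitchen lights",   # text: the user utterance
    MagicMock(),                    # context: an HA Context (mocked in tests)
    "test-conversation",            # conversation_id: threads multi-turn chats
    None,                           # device_id: no originating device
    "en",                           # language
    agent_id="test-agent",          # the agent asked to handle the request
)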
@@ -249,7 +255,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
 
     # do another turn of the same conversation
     result = await local_llama_agent.async_process(ConversationInput(
-        "turn off the office lights", MagicMock(), conversation_id, None, "en"
+        "turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
     ))
 
     assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -272,8 +278,7 @@ def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
         patch.object(OllamaAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
         patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
         patch('homeassistant.helpers.template.Template') as template_mock, \
-        patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
-        patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
+        patch('custom_components.llama_conversation.backends.ollama.async_get_clientsession') as get_clientsession:
 
         entry_mock.return_value = config_entry
         get_exposed_entities_mock.return_value = (
@@ -288,7 +293,7 @@ def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
 
         response_mock = MagicMock()
         response_mock.json.return_value = { "models": [ {"name": config_entry.data[CONF_CHAT_MODEL] }] }
-        requests_get_mock.return_value = response_mock
+        get_clientsession.get.return_value = response_mock
 
         call_tool_mock.return_value = {"result": "success"}
 
@@ -298,8 +303,8 @@ def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
         )
 
         all_mocks = {
-            "requests_get": requests_get_mock,
-            "requests_post": requests_post_mock
+            "requests_get": get_clientsession.get,
+            "requests_post": get_clientsession.post
         }
 
         yield agent_obj, all_mocks
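The renamed mocks capture the real migration: the Ollama backend drops blocking `requests` calls for Home Assistant's shared aiohttp session. A sketch of the pattern, assuming the backend calls `async_get_clientsession(hass)` and then `session.get(...)`/`session.post(...)`; the fixture keeps the old `requests_*` dictionary keys so the downstream assertions don't have to change:

from unittest.mock import MagicMock, patch

# Patch the factory where the backend imports it, then hand the test the
# .get/.post attributes to assert on. Note that a real aiohttp response's
# .json() must be awaited, so a fully async backend would want AsyncMock here.
with patch('custom_components.llama_conversation.backends.ollama.async_get_clientsession') as get_clientsession:
    response_mock = MagicMock()
    response_mock.json.return_value = {"models": [{"name": "test-model"}]}
    get_clientsession.get.return_value = response_mock

    all_mocks = {
        "requests_get": get_clientsession.get,     # legacy key, new mock
        "requests_post": get_clientsession.post,
    }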
@@ -339,7 +344,7 @@ async def test_ollama_agent(ollama_agent_fixture):
     # invoke the conversation agent
     conversation_id = "test-conversation"
     result = await ollama_agent.async_process(ConversationInput(
-        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
+        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
     ))
 
     # assert on results + check side effects
@@ -384,7 +389,7 @@ async def test_ollama_agent(ollama_agent_fixture):
 
     # do another turn of the same conversation
     result = await ollama_agent.async_process(ConversationInput(
-        "turn off the office lights", MagicMock(), conversation_id, None, "en"
+        "turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
    ))
 
     assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -418,8 +423,7 @@ def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integrations):
         patch.object(TextGenerationWebuiAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
         patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
         patch('homeassistant.helpers.template.Template') as template_mock, \
-        patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
-        patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
+        patch('custom_components.llama_conversation.backends.tailored_openai.async_get_clientsession') as get_clientsession_mock:
 
         entry_mock.return_value = config_entry
         get_exposed_entities_mock.return_value = (
@@ -434,7 +438,7 @@ def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integrations):
 
         response_mock = MagicMock()
         response_mock.json.return_value = { "model_name": config_entry.data[CONF_CHAT_MODEL] }
-        requests_get_mock.return_value = response_mock
+        get_clientsession_mock.get.return_value = response_mock
 
         call_tool_mock.return_value = {"result": "success"}
 
@@ -444,8 +448,8 @@ def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integrations):
         )
 
         all_mocks = {
-            "requests_get": requests_get_mock,
-            "requests_post": requests_post_mock
+            "requests_get": get_clientsession_mock.get,
+            "requests_post": get_clientsession_mock.post
         }
 
         yield agent_obj, all_mocks
@@ -490,7 +494,7 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
     # invoke the conversation agent
     conversation_id = "test-conversation"
    result = await text_generation_webui_agent.async_process(ConversationInput(
-        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
+        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
     ))
 
     # assert on results + check side effects
@@ -521,7 +525,7 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
 
     # do another turn of the same conversation and use a preset
     result = await text_generation_webui_agent.async_process(ConversationInput(
-        "turn off the office lights", MagicMock(), conversation_id, None, "en"
+        "turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
     ))
 
     assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -588,7 +592,7 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
 
     # do another turn of the same conversation but the chat endpoint
     result = await text_generation_webui_agent.async_process(ConversationInput(
-        "turn off the office lights", MagicMock(), conversation_id, None, "en"
+        "turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
     ))
 
     assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -611,69 +615,6 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
         timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
     )
 
-    # reset mock stats
-    for mock in all_mocks.values():
-        mock.reset_mock()
-
-    text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = "Some Character"
-
-    # do another turn of the same conversation and use a preset
-    result = await text_generation_webui_agent.async_process(ConversationInput(
-        "turn off the office lights", MagicMock(), conversation_id, None, "en"
-    ))
-
-    assert result.response.speech['plain']['speech'] == "I am saying something!"
-
-    all_mocks["requests_post"].assert_called_once_with(
-        "http://localhost:5000/v1/chat/completions",
-        json={
-            "model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
-            "top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
-            "top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
-            "temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
-            "min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
-            "typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
-            "truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
-            "max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
-            "mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
-            "character": "Some Character",
-            "messages": ANY
-        },
-        headers={ "Authorization": "Bearer OpenAI-API-Key" },
-        timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
-    )
-
-    # reset mock stats
-    for mock in all_mocks.values():
-        mock.reset_mock()
-
-    text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE] = TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT
-
-    # do another turn of the same conversation and use instruct mode
-    result = await text_generation_webui_agent.async_process(ConversationInput(
-        "turn off the office lights", MagicMock(), conversation_id, None, "en"
-    ))
-
-    assert result.response.speech['plain']['speech'] == "I am saying something!"
-
-    all_mocks["requests_post"].assert_called_once_with(
-        "http://localhost:5000/v1/chat/completions",
-        json={
-            "model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
-            "top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
-            "top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
-            "min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
-            "typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
-            "temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
-            "truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
-            "max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
-            "mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
-            "instruction_template": "chatml",
-            "messages": ANY
-        },
-        headers={ "Authorization": "Bearer OpenAI-API-Key" },
-        timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
-    )
 
 @pytest.fixture
 def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations):
@@ -682,8 +623,7 @@ def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations):
         patch.object(GenericOpenAIAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
         patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
         patch('homeassistant.helpers.template.Template') as template_mock, \
-        patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
-        patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
+        patch('custom_components.llama_conversation.backends.generic_openai.async_get_clientsession') as get_clientsession_mock:
 
         entry_mock.return_value = config_entry
         get_exposed_entities_mock.return_value = (
@@ -704,8 +644,8 @@ def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations):
         )
 
         all_mocks = {
-            "requests_get": requests_get_mock,
-            "requests_post": requests_post_mock
+            "requests_get": get_clientsession_mock.get,
+            "requests_post": get_clientsession_mock.post
         }
 
         yield agent_obj, all_mocks
@@ -745,7 +685,7 @@ async def test_generic_openai_agent(generic_openai_agent_fixture):
     # invoke the conversation agent
     conversation_id = "test-conversation"
     result = await generic_openai_agent.async_process(ConversationInput(
-        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
+        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
     ))
 
     # assert on results + check side effects
@@ -804,7 +744,7 @@ async def test_generic_openai_agent(generic_openai_agent_fixture):
 
     # do another turn of the same conversation but the chat endpoint
     result = await generic_openai_agent.async_process(ConversationInput(
-        "turn off the office lights", MagicMock(), conversation_id, None, "en"
+        "turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
     ))
 
     assert result.response.speech['plain']['speech'] == "I am saying something!"
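All four fixtures also stub `APIInstance.async_call_tool`, the seam the commit's new tool support will flow through. Since Python 3.8, `patch.object` substitutes an `AsyncMock` automatically when the target is an `async def`, which is why the fixtures' plain `patch.object(APIInstance, 'async_call_tool')` works. A sketch of stubbing and asserting on it, assuming `APIInstance` is Home Assistant's `homeassistant.helpers.llm.APIInstance`:

from unittest.mock import AsyncMock, patch

from homeassistant.helpers.llm import APIInstance  # assumed import path

# async_call_tool is a coroutine, so an AsyncMock lets the agent await the
# stubbed result and lets the test verify a tool was actually invoked.
with patch.object(APIInstance, 'async_call_tool', new_callable=AsyncMock) as call_tool_mock:
    call_tool_mock.return_value = {"result": "success"}
    # ... await agent.async_process(ConversationInput(...)) ...
    call_tool_mock.assert_awaited_once()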