mirror of https://github.com/acon96/home-llm.git (synced 2026-01-09 13:48:05 -05:00)

commit: finish agent tests
@@ -846,7 +846,7 @@ class GenericOpenAIAPIAgent(LLaMAAgent):
     def _extract_response(self, response_json: dict) -> str:
         choices = response_json["choices"]
         if choices[0]["finish_reason"] != "stop":
-            _LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")
+            _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
 
         if response_json["object"] == "chat.completion":
             return choices[0]["message"]["content"]
@@ -972,7 +972,7 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
     def _extract_response(self, response_json: dict) -> str:
         choices = response_json["choices"]
         if choices[0]["finish_reason"] != "stop":
-            _LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")
+            _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
 
         context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
         max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -1069,7 +1069,7 @@ class OllamaAPIAgent(LLaMAAgent):
 
     def _extract_response(self, response_json: dict) -> str:
         if response_json["done"] != "true":
-            _LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")
+            _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
 
         # TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
         # context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
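
A note on the change above: logging.Logger.warn is an undocumented, deprecated alias of Logger.warning, and current Python versions emit a DeprecationWarning when it is called, so all three backends now use the documented spelling. A minimal standalone sketch of the pattern (the helper name check_finish_reason is illustrative, not part of the component):

import logging

_LOGGER = logging.getLogger(__name__)

def check_finish_reason(finish_reason: str) -> None:
    # Logger.warning is the documented API; Logger.warn is a deprecated
    # alias that triggers a DeprecationWarning on current Python versions.
    if finish_reason != "stop":
        _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
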
@@ -168,7 +168,6 @@ DEFAULT_OPTIONS = types.MappingProxyType(
         CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
         CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
         CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
-        CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
         CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
         CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
         CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
@@ -177,6 +176,10 @@ DEFAULT_OPTIONS = types.MappingProxyType(
         CONF_THREAD_COUNT: DEFAULT_THREAD_COUNT,
         CONF_BATCH_THREAD_COUNT: DEFAULT_BATCH_THREAD_COUNT,
         CONF_PROMPT_CACHING_ENABLED: DEFAULT_PROMPT_CACHING_ENABLED,
+        CONF_OLLAMA_KEEP_ALIVE_MIN: DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
+        CONF_OLLAMA_JSON_MODE: DEFAULT_OLLAMA_JSON_MODE,
+        CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
+        CONF_TEXT_GEN_WEBUI_PRESET: ""
     }
 )
 
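
DEFAULT_OPTIONS above is wrapped in types.MappingProxyType, which gives a read-only view of the underlying dict: lookups behave normally, but assignment raises TypeError, so the defaults cannot be mutated by accident. A small sketch of that behavior (the keys here are placeholders, not the component's CONF_* constants):

import types

DEFAULTS = types.MappingProxyType({
    "context_length": 2048,
    "max_tokens": 128,
})

assert DEFAULTS["max_tokens"] == 128   # reads work like a plain dict
try:
    DEFAULTS["max_tokens"] = 256       # writes raise TypeError
except TypeError:
    overrides = dict(DEFAULTS)         # copy the proxy first to customize
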
@@ -4,7 +4,7 @@ import pytest
 import jinja2
 from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY
 
-from custom_components.llama_conversation.agent import LocalLLaMAAgent
+from custom_components.llama_conversation.agent import LocalLLaMAAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
 from custom_components.llama_conversation.const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
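
All of the test code added below stubs the HTTP layer by patching requests.get and requests.post at the name the agent module looks them up under ('custom_components.llama_conversation.agent.requests.get'), which is the usual patch-where-it-is-used rule from unittest.mock. A self-contained sketch of the same idea, patching requests.get directly since this snippet calls it itself (assumes the requests package is installed):

import requests
from unittest.mock import MagicMock, patch

with patch('requests.get') as requests_get_mock:
    # Canned JSON body, mirroring how the fixtures fake the /api/tags response.
    response_mock = MagicMock()
    response_mock.json.return_value = {"models": [{"name": "llama"}]}
    requests_get_mock.return_value = response_mock

    body = requests.get("http://localhost:11434/api/tags").json()
    assert body["models"][0]["name"] == "llama"   # no real HTTP happened
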
@@ -254,4 +254,531 @@ async def test_local_llama_agent(local_llama_agent_fixture):
         grammar=ANY,
     )
 
-# TODO: test backends: text-gen-webui, ollama, generic openai
+# TODO: test backends: text-gen-webui, ollama, generic openai
+
+@pytest.fixture
+def ollama_agent_fixture(config_entry, home_assistant_mock):
+    with patch.object(OllamaAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
+         patch.object(OllamaAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
+         patch.object(OllamaAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
+         patch('homeassistant.helpers.template.Template') as template_mock, \
+         patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
+         patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
+
+        entry_mock.return_value = config_entry
+        get_exposed_entities_mock.return_value = (
+            {
+                "light.kitchen_light": { "state": "on" },
+                "light.office_lamp": { "state": "on" },
+                "switch.downstairs_hallway": { "state": "off" },
+                "fan.bedroom": { "state": "on" },
+            },
+            ["light", "switch", "fan"]
+        )
+
+        response_mock = MagicMock()
+        response_mock.json.return_value = { "models": [ {"name": config_entry.data[CONF_CHAT_MODEL] }] }
+        requests_get_mock.return_value = response_mock
+
+        agent_obj = OllamaAPIAgent(
+            home_assistant_mock,
+            config_entry
+        )
+
+        all_mocks = {
+            "requests_get": requests_get_mock,
+            "requests_post": requests_post_mock
+        }
+
+        yield agent_obj, all_mocks
+
+async def test_ollama_agent(ollama_agent_fixture):
+
+    ollama_agent: OllamaAPIAgent
+    all_mocks: dict[str, MagicMock]
+    ollama_agent, all_mocks = ollama_agent_fixture
+
+    all_mocks["requests_get"].assert_called_once_with(
+        "http://localhost:5000/api/tags",
+        headers={ "Authorization": "Bearer OpenAI-API-Key" }
+    )
+
+    response_mock = MagicMock()
+    response_mock.json.return_value = {
+        "model": ollama_agent.entry.data[CONF_CHAT_MODEL],
+        "created_at": "2023-11-09T21:07:55.186497Z",
+        "response": json.dumps({
+            "to_say": "I am saying something!",
+            "service": "light.turn_on",
+            "target_device": "light.kitchen_light",
+        }),
+        "done": True,
+        "context": [1, 2, 3],
+        "total_duration": 4648158584,
+        "load_duration": 4071084,
+        "prompt_eval_count": 36,
+        "prompt_eval_duration": 439038000,
+        "eval_count": 180,
+        "eval_duration": 4196918000
+    }
+    all_mocks["requests_post"].return_value = response_mock
+
+    # invoke the conversation agent
+    conversation_id = "test-conversation"
+    result = await ollama_agent.async_process(ConversationInput(
+        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
+    ))
+
+    # assert on results + check side effects
+    assert result.response.speech['plain']['speech'] == "I am saying something!"
+
+    all_mocks["requests_post"].assert_called_once_with(
+        "http://localhost:5000/api/generate",
+        headers={ "Authorization": "Bearer OpenAI-API-Key" },
+        json={
+            "model": ollama_agent.entry.data[CONF_CHAT_MODEL],
+            "stream": False,
+            "keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
+            "options": {
+                "num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
+                "top_p": ollama_agent.entry.options[CONF_TOP_P],
+                "top_k": ollama_agent.entry.options[CONF_TOP_K],
+                "temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
+                "num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
+            },
+            "prompt": ANY
+        },
+        timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
+    )
+
+    # reset mock stats
+    for mock in all_mocks.values():
+        mock.reset_mock()
+
+    # change options
+    ollama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
+    ollama_agent.entry.options[CONF_MAX_TOKENS] = 10
+    ollama_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
+    ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN] = 99
+    ollama_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
+    ollama_agent.entry.options[CONF_OLLAMA_JSON_MODE] = True
+    ollama_agent.entry.options[CONF_TEMPERATURE] = 2.0
+    ollama_agent.entry.options[CONF_TOP_K] = 20
+    ollama_agent.entry.options[CONF_TOP_P] = 0.9
+
+    # do another turn of the same conversation
+    result = await ollama_agent.async_process(ConversationInput(
+        "turn off the office lights", MagicMock(), conversation_id, None, "en"
+    ))
+
+    assert result.response.speech['plain']['speech'] == "I am saying something!"
+
+    all_mocks["requests_post"].assert_called_once_with(
+        "http://localhost:5000/api/chat",
+        headers={ "Authorization": "Bearer OpenAI-API-Key" },
+        json={
+            "model": ollama_agent.entry.data[CONF_CHAT_MODEL],
+            "stream": False,
+            "format": "json",
+            "keep_alive": f"{ollama_agent.entry.options[CONF_OLLAMA_KEEP_ALIVE_MIN]}m", # prevent ollama from unloading the model
+            "options": {
+                "num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
+                "top_p": ollama_agent.entry.options[CONF_TOP_P],
+                "top_k": ollama_agent.entry.options[CONF_TOP_K],
+                "temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
+                "num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
+            },
+            "messages": ANY
+        },
+        timeout=ollama_agent.entry.options[CONF_REQUEST_TIMEOUT]
+    )
+
+
+@pytest.fixture
+def text_generation_webui_agent_fixture(config_entry, home_assistant_mock):
+    with patch.object(TextGenerationWebuiAgent, '_load_icl_examples') as load_icl_examples_mock, \
+         patch.object(TextGenerationWebuiAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
+         patch.object(TextGenerationWebuiAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
+         patch('homeassistant.helpers.template.Template') as template_mock, \
+         patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
+         patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
+
+        entry_mock.return_value = config_entry
+        get_exposed_entities_mock.return_value = (
+            {
+                "light.kitchen_light": { "state": "on" },
+                "light.office_lamp": { "state": "on" },
+                "switch.downstairs_hallway": { "state": "off" },
+                "fan.bedroom": { "state": "on" },
+            },
+            ["light", "switch", "fan"]
+        )
+
+        response_mock = MagicMock()
+        response_mock.json.return_value = { "model_name": config_entry.data[CONF_CHAT_MODEL] }
+        requests_get_mock.return_value = response_mock
+
+        agent_obj = TextGenerationWebuiAgent(
+            home_assistant_mock,
+            config_entry
+        )
+
+        all_mocks = {
+            "requests_get": requests_get_mock,
+            "requests_post": requests_post_mock
+        }
+
+        yield agent_obj, all_mocks
+
+async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
+
+    text_generation_webui_agent: TextGenerationWebuiAgent
+    all_mocks: dict[str, MagicMock]
+    text_generation_webui_agent, all_mocks = text_generation_webui_agent_fixture
+
+    all_mocks["requests_get"].assert_called_once_with(
+        "http://localhost:5000/v1/internal/model/info",
+        headers={ "Authorization": "Bearer Text-Gen-Webui-Admin-Key" }
+    )
+
+    response_mock = MagicMock()
+    response_mock.json.return_value = {
+        "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
+        "object": "text_completion",
+        "created": 1589478378,
+        "model": "gpt-3.5-turbo-instruct",
+        "system_fingerprint": "fp_44709d6fcb",
+        "choices": [{
+            "text": json.dumps({
+                "to_say": "I am saying something!",
+                "service": "light.turn_on",
+                "target_device": "light.kitchen_light",
+            }),
+            "index": 0,
+            "logprobs": None,
+            "finish_reason": "length"
+        }],
+        "usage": {
+            "prompt_tokens": 5,
+            "completion_tokens": 7,
+            "total_tokens": 12
+        }
+    }
+    all_mocks["requests_post"].return_value = response_mock
+
+    # invoke the conversation agent
+    conversation_id = "test-conversation"
+    result = await text_generation_webui_agent.async_process(ConversationInput(
+        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
+    ))
+
+    # assert on results + check side effects
+    assert result.response.speech['plain']['speech'] == "I am saying something!"
+
+    all_mocks["requests_post"].assert_called_once_with(
+        "http://localhost:5000/v1/completions",
+        json={
+            "model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
+            "top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
+            "top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
+            "temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
+            "truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
+            "max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
+            "prompt": ANY
+        },
+        headers={ "Authorization": "Bearer OpenAI-API-Key" },
+        timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
+    )
+
+    # reset mock stats
+    for mock in all_mocks.values():
+        mock.reset_mock()
+
+    text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = "Some Preset"
+
+    # do another turn of the same conversation and use a preset
+    result = await text_generation_webui_agent.async_process(ConversationInput(
+        "turn off the office lights", MagicMock(), conversation_id, None, "en"
+    ))
+
+    assert result.response.speech['plain']['speech'] == "I am saying something!"
+
+    all_mocks["requests_post"].assert_called_once_with(
+        "http://localhost:5000/v1/completions",
+        json={
+            "model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
+            "top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
+            "top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
+            "temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
+            "truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
+            "max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
+            "preset": "Some Preset",
+            "prompt": ANY
+        },
+        headers={ "Authorization": "Bearer OpenAI-API-Key" },
+        timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
+    )
+
+    # change options
+    text_generation_webui_agent.entry.options[CONF_MAX_TOKENS] = 10
+    text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
+    text_generation_webui_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
+    text_generation_webui_agent.entry.options[CONF_TEMPERATURE] = 2.0
+    text_generation_webui_agent.entry.options[CONF_TOP_P] = 0.9
+    text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = ""
+
+    response_mock.json.return_value = {
+        "id": "chatcmpl-123",
+        # text-gen-webui has a typo where it returns 'chat.completions', not 'chat.completion'
+        "object": "chat.completions",
+        "created": 1677652288,
+        "model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
+        "system_fingerprint": "fp_44709d6fcb",
+        "choices": [{
+            "index": 0,
+            "message": {
+                "role": "assistant",
+                "content": json.dumps({
+                    "to_say": "I am saying something!",
+                    "service": "light.turn_on",
+                    "target_device": "light.kitchen_light",
+                }),
+            },
+            "logprobs": None,
+            "finish_reason": "stop"
+        }],
+        "usage": {
+            "prompt_tokens": 9,
+            "completion_tokens": 12,
+            "total_tokens": 21
+        }
+    }
+
+    # reset mock stats
+    for mock in all_mocks.values():
+        mock.reset_mock()
+
+    # do another turn of the same conversation, but using the chat endpoint
+    result = await text_generation_webui_agent.async_process(ConversationInput(
+        "turn off the office lights", MagicMock(), conversation_id, None, "en"
+    ))
+
+    assert result.response.speech['plain']['speech'] == "I am saying something!"
+
+    all_mocks["requests_post"].assert_called_once_with(
+        "http://localhost:5000/v1/chat/completions",
+        json={
+            "model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
+            "top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
+            "top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
+            "temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
+            "truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
+            "max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
+            "mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
+            "messages": ANY
+        },
+        headers={ "Authorization": "Bearer OpenAI-API-Key" },
+        timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
+    )
+
+    # reset mock stats
+    for mock in all_mocks.values():
+        mock.reset_mock()
+
+    text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = "Some Character"
+
+    # do another turn of the same conversation and use a character
+    result = await text_generation_webui_agent.async_process(ConversationInput(
+        "turn off the office lights", MagicMock(), conversation_id, None, "en"
+    ))
+
+    assert result.response.speech['plain']['speech'] == "I am saying something!"
+
+    all_mocks["requests_post"].assert_called_once_with(
+        "http://localhost:5000/v1/chat/completions",
+        json={
+            "model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
+            "top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
+            "top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
+            "temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
+            "truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
+            "max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
+            "mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
+            "character": "Some Character",
+            "messages": ANY
+        },
+        headers={ "Authorization": "Bearer OpenAI-API-Key" },
+        timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
+    )
+
+    # reset mock stats
+    for mock in all_mocks.values():
+        mock.reset_mock()
+
+    text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE] = TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT
+
+    # do another turn of the same conversation and use instruct mode
+    result = await text_generation_webui_agent.async_process(ConversationInput(
+        "turn off the office lights", MagicMock(), conversation_id, None, "en"
+    ))
+
+    assert result.response.speech['plain']['speech'] == "I am saying something!"
+
+    all_mocks["requests_post"].assert_called_once_with(
+        "http://localhost:5000/v1/chat/completions",
+        json={
+            "model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
+            "top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
+            "top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
+            "temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
+            "truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
+            "max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
+            "mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
+            "instruction_template": "chatml",
+            "messages": ANY
+        },
+        headers={ "Authorization": "Bearer OpenAI-API-Key" },
+        timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
+    )
+
+@pytest.fixture
+def generic_openai_agent_fixture(config_entry, home_assistant_mock):
+    with patch.object(GenericOpenAIAPIAgent, '_load_icl_examples') as load_icl_examples_mock, \
+         patch.object(GenericOpenAIAPIAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
+         patch.object(GenericOpenAIAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
+         patch('homeassistant.helpers.template.Template') as template_mock, \
+         patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
+         patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
+
+        entry_mock.return_value = config_entry
+        get_exposed_entities_mock.return_value = (
+            {
+                "light.kitchen_light": { "state": "on" },
+                "light.office_lamp": { "state": "on" },
+                "switch.downstairs_hallway": { "state": "off" },
+                "fan.bedroom": { "state": "on" },
+            },
+            ["light", "switch", "fan"]
+        )
+
+        agent_obj = GenericOpenAIAPIAgent(
+            home_assistant_mock,
+            config_entry
+        )
+
+        all_mocks = {
+            "requests_get": requests_get_mock,
+            "requests_post": requests_post_mock
+        }
+
+        yield agent_obj, all_mocks
+
+async def test_generic_openai_agent(generic_openai_agent_fixture):
+
+    generic_openai_agent: GenericOpenAIAPIAgent
+    all_mocks: dict[str, MagicMock]
+    generic_openai_agent, all_mocks = generic_openai_agent_fixture
+
+    response_mock = MagicMock()
+    response_mock.json.return_value = {
+        "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
+        "object": "text_completion",
+        "created": 1589478378,
+        "model": "gpt-3.5-turbo-instruct",
+        "system_fingerprint": "fp_44709d6fcb",
+        "choices": [{
+            "text": json.dumps({
+                "to_say": "I am saying something!",
+                "service": "light.turn_on",
+                "target_device": "light.kitchen_light",
+            }),
+            "index": 0,
+            "logprobs": None,
+            "finish_reason": "length"
+        }],
+        "usage": {
+            "prompt_tokens": 5,
+            "completion_tokens": 7,
+            "total_tokens": 12
+        }
+    }
+    all_mocks["requests_post"].return_value = response_mock
+
+    # invoke the conversation agent
+    conversation_id = "test-conversation"
+    result = await generic_openai_agent.async_process(ConversationInput(
+        "turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
+    ))
+
+    # assert on results + check side effects
+    assert result.response.speech['plain']['speech'] == "I am saying something!"
+
+    all_mocks["requests_post"].assert_called_once_with(
+        "http://localhost:5000/v1/completions",
+        json={
+            "model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
+            "top_p": generic_openai_agent.entry.options[CONF_TOP_P],
+            "temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
+            "max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
+            "prompt": ANY
+        },
+        headers={ "Authorization": "Bearer OpenAI-API-Key" },
+        timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
+    )
+
+    # reset mock stats
+    for mock in all_mocks.values():
+        mock.reset_mock()
+
+    # change options
+    generic_openai_agent.entry.options[CONF_MAX_TOKENS] = 10
+    generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT] = 60
+    generic_openai_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
+    generic_openai_agent.entry.options[CONF_TEMPERATURE] = 2.0
+    generic_openai_agent.entry.options[CONF_TOP_P] = 0.9
+
+    response_mock.json.return_value = {
+        "id": "chatcmpl-123",
+        "object": "chat.completion",
+        "created": 1677652288,
+        "model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
+        "system_fingerprint": "fp_44709d6fcb",
+        "choices": [{
+            "index": 0,
+            "message": {
+                "role": "assistant",
+                "content": json.dumps({
+                    "to_say": "I am saying something!",
+                    "service": "light.turn_on",
+                    "target_device": "light.kitchen_light",
+                }),
+            },
+            "logprobs": None,
+            "finish_reason": "stop"
+        }],
+        "usage": {
+            "prompt_tokens": 9,
+            "completion_tokens": 12,
+            "total_tokens": 21
+        }
+    }
+
+    # do another turn of the same conversation, but using the chat endpoint
+    result = await generic_openai_agent.async_process(ConversationInput(
+        "turn off the office lights", MagicMock(), conversation_id, None, "en"
+    ))
+
+    assert result.response.speech['plain']['speech'] == "I am saying something!"
+
+    all_mocks["requests_post"].assert_called_once_with(
+        "http://localhost:5000/v1/chat/completions",
+        json={
+            "model": generic_openai_agent.entry.data[CONF_CHAT_MODEL],
+            "top_p": generic_openai_agent.entry.options[CONF_TOP_P],
+            "temperature": generic_openai_agent.entry.options[CONF_TEMPERATURE],
+            "max_tokens": generic_openai_agent.entry.options[CONF_MAX_TOKENS],
+            "messages": ANY
+        },
+        headers={ "Authorization": "Bearer OpenAI-API-Key" },
+        timeout=generic_openai_agent.entry.options[CONF_REQUEST_TIMEOUT]
+    )
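
One pattern repeated throughout these tests: each payload assertion pins the URL, headers, timeout, and every sampling option exactly, and uses unittest.mock.ANY as a wildcard only for the rendered prompt or messages, whose exact text is not under test. A self-contained sketch of how ANY behaves inside assert_called_once_with:

from unittest.mock import ANY, MagicMock

post_mock = MagicMock()
post_mock(
    "http://localhost:5000/v1/completions",
    json={"model": "llama", "prompt": "turn on the lights"},
    timeout=90,
)

# ANY compares equal to anything, so the prompt text is ignored while
# every other field is still checked exactly.
post_mock.assert_called_once_with(
    "http://localhost:5000/v1/completions",
    json={"model": "llama", "prompt": ANY},
    timeout=90,
)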