fix tests

Alex O'Connell
2024-05-03 22:52:34 -04:00
parent cdd7e8415a
commit 9b9de48ad4
2 changed files with 9 additions and 2 deletions

View File

@@ -20,6 +20,7 @@ from custom_components.llama_conversation.const import (
     CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
     CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
     CONF_PROMPT_TEMPLATE,
+    CONF_ENABLE_FLASH_ATTENTION,
     CONF_USE_GBNF_GRAMMAR,
     CONF_GBNF_GRAMMAR_FILE,
     CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -55,6 +56,7 @@ from custom_components.llama_conversation.const import (
     DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
     DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
     DEFAULT_PROMPT_TEMPLATE,
+    DEFAULT_ENABLE_FLASH_ATTENTION,
     DEFAULT_USE_GBNF_GRAMMAR,
     DEFAULT_GBNF_GRAMMAR_FILE,
     DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -208,6 +210,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
         n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
         n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
         n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
+        flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
     )
     all_mocks["tokenize"].assert_called_once()
@@ -231,6 +234,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
     local_llama_agent.entry.options[CONF_THREAD_COUNT] = 24
     local_llama_agent.entry.options[CONF_BATCH_THREAD_COUNT] = 24
     local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
+    local_llama_agent.entry.options[CONF_ENABLE_FLASH_ATTENTION] = True
     local_llama_agent.entry.options[CONF_TOP_K] = 20
     local_llama_agent.entry.options[CONF_TOP_P] = 0.9
     local_llama_agent.entry.options[CONF_MIN_P] = 0.2
@@ -244,6 +248,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
         n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
         n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
         n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
+        flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
     )
     # do another turn of the same conversation
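
For reference, the behaviour these assertions exercise is the llama.cpp backend forwarding the new CONF_ENABLE_FLASH_ATTENTION option to the llama_cpp.Llama constructor as its flash_attn keyword (a constructor argument in recent llama-cpp-python releases). Below is a minimal sketch of that wiring; the load_model helper is hypothetical and only mirrors the keyword arguments asserted in the test, not the integration's actual loader.

    import llama_cpp

    from custom_components.llama_conversation.const import (
        CONF_BATCH_SIZE,
        CONF_THREAD_COUNT,
        CONF_BATCH_THREAD_COUNT,
        CONF_ENABLE_FLASH_ATTENTION,
        DEFAULT_ENABLE_FLASH_ATTENTION,
    )

    def load_model(model_path: str, options: dict) -> llama_cpp.Llama:
        # Hypothetical helper: constructs the model with the same keyword
        # arguments the test asserts on, falling back to the default when
        # the flash attention option has not been set.
        return llama_cpp.Llama(
            model_path=model_path,
            n_batch=options.get(CONF_BATCH_SIZE),
            n_threads=options.get(CONF_THREAD_COUNT),
            n_threads_batch=options.get(CONF_BATCH_THREAD_COUNT),
            flash_attn=options.get(
                CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION
            ),
        )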

View File

@@ -26,6 +26,7 @@ from custom_components.llama_conversation.const import (
     CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
     CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
     CONF_PROMPT_TEMPLATE,
+    CONF_ENABLE_FLASH_ATTENTION,
     CONF_USE_GBNF_GRAMMAR,
     CONF_GBNF_GRAMMAR_FILE,
     CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -67,6 +68,7 @@ from custom_components.llama_conversation.const import (
     DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
     DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
     DEFAULT_PROMPT_TEMPLATE,
+    DEFAULT_ENABLE_FLASH_ATTENTION,
     DEFAULT_USE_GBNF_GRAMMAR,
     DEFAULT_GBNF_GRAMMAR_FILE,
     DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -304,7 +306,7 @@ def test_validate_options_schema():
     options_llama_hf = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_HF)
     assert set(options_llama_hf.keys()) == set(universal_options + [
         CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
-        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
+        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
         CONF_CONTEXT_LENGTH, # supports context length
         CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
         CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
@@ -313,7 +315,7 @@ def test_validate_options_schema():
     options_llama_existing = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_EXISTING)
     assert set(options_llama_existing.keys()) == set(universal_options + [
         CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
-        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
+        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
         CONF_CONTEXT_LENGTH, # supports context length
         CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
         CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
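
The schema change can also be spot-checked in isolation. A rough sketch follows, assuming local_llama_config_option_schema is importable from the integration's config_flow module and that the schema keys are voluptuous markers that stringify to the option names; both module paths and that stringification detail are assumptions for illustration, not taken from this commit.

    from custom_components.llama_conversation.config_flow import (
        local_llama_config_option_schema,
    )
    from custom_components.llama_conversation.const import (
        BACKEND_TYPE_LLAMA_HF,
        CONF_ENABLE_FLASH_ATTENTION,
    )

    def test_flash_attention_option_exposed():
        # The llama.cpp backend's options schema should now include the
        # flash attention toggle alongside the other llama.cpp settings.
        schema = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_HF)
        key_names = {str(key) for key in schema}  # markers stringify to option names
        assert CONF_ENABLE_FLASH_ATTENTION in key_names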