fix tests
@@ -20,6 +20,7 @@ from custom_components.llama_conversation.const import (
     CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
     CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
     CONF_PROMPT_TEMPLATE,
+    CONF_ENABLE_FLASH_ATTENTION,
     CONF_USE_GBNF_GRAMMAR,
     CONF_GBNF_GRAMMAR_FILE,
     CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,

@@ -55,6 +56,7 @@ from custom_components.llama_conversation.const import (
     DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
     DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
     DEFAULT_PROMPT_TEMPLATE,
+    DEFAULT_ENABLE_FLASH_ATTENTION,
     DEFAULT_USE_GBNF_GRAMMAR,
     DEFAULT_GBNF_GRAMMAR_FILE,
     DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,

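The import hunks above (and the matching pair further down) pull a new constant pair out of the integration's const module. As a purely illustrative sketch of what such a pair looks like (the actual string value and default are not shown in this diff and may differ):

# Hypothetical const.py entries -- values assumed for illustration only,
# not copied from the repository.
CONF_ENABLE_FLASH_ATTENTION = "enable_flash_attention"
DEFAULT_ENABLE_FLASH_ATTENTION = False
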
@@ -208,6 +210,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
         n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
         n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
         n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
+        flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
     )

     all_mocks["tokenize"].assert_called_once()

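The hunk above extends the expected llama_cpp.Llama constructor call with the flash_attn keyword, i.e. it checks that the agent forwards the new config-entry option to the llama.cpp backend. A minimal, self-contained sketch of that assertion pattern, with a hypothetical load_model helper and plain-string option names standing in for the integration's real code (flash_attn is a constructor argument in recent llama-cpp-python releases, but treat the exact signature as an assumption):

from unittest.mock import patch

def load_model(model_path: str, options: dict):
    # Stand-in for the integration's model-loading code: it maps
    # config-entry options onto llama_cpp.Llama keyword arguments.
    # Requires llama-cpp-python to be importable.
    import llama_cpp
    return llama_cpp.Llama(
        model_path=model_path,
        n_batch=options.get("batch_size"),
        n_threads=options.get("n_threads"),
        n_threads_batch=options.get("n_batch_threads"),
        flash_attn=options.get("enable_flash_attention"),
    )

def test_flash_attn_is_forwarded():
    options = {
        "batch_size": 512,
        "n_threads": 8,
        "n_batch_threads": 8,
        "enable_flash_attention": True,
    }
    with patch("llama_cpp.Llama") as llama_mock:
        load_model("/models/example.gguf", options)
        llama_mock.assert_called_once_with(
            model_path="/models/example.gguf",
            n_batch=512,
            n_threads=8,
            n_threads_batch=8,
            flash_attn=True,
        )
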
@@ -231,6 +234,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
     local_llama_agent.entry.options[CONF_THREAD_COUNT] = 24
     local_llama_agent.entry.options[CONF_BATCH_THREAD_COUNT] = 24
     local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
+    local_llama_agent.entry.options[CONF_ENABLE_FLASH_ATTENTION] = True
     local_llama_agent.entry.options[CONF_TOP_K] = 20
     local_llama_agent.entry.options[CONF_TOP_P] = 0.9
     local_llama_agent.entry.options[CONF_MIN_P] = 0.2

@@ -244,6 +248,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
         n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
         n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
         n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
+        flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
     )

     # do another turn of the same conversation

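The second assertion hunk repeats the constructor check after the options have been changed (including CONF_ENABLE_FLASH_ATTENTION set to True), i.e. it verifies that reloading the model picks up the updated settings. Reusing the hypothetical load_model helper and patch import from the sketch above, the shape of that follow-up check is roughly:

def test_flash_attn_respects_updated_options():
    options = {
        "batch_size": 512,
        "n_threads": 8,
        "n_batch_threads": 8,
        "enable_flash_attention": False,
    }
    with patch("llama_cpp.Llama") as llama_mock:
        load_model("/models/example.gguf", options)
        assert llama_mock.call_args.kwargs["flash_attn"] is False

        # Simulate the user flipping the option and the integration
        # reloading the model with the new settings.
        llama_mock.reset_mock()
        options["enable_flash_attention"] = True
        load_model("/models/example.gguf", options)
        assert llama_mock.call_args.kwargs["flash_attn"] is True
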
@@ -26,6 +26,7 @@ from custom_components.llama_conversation.const import (
     CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
     CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
     CONF_PROMPT_TEMPLATE,
+    CONF_ENABLE_FLASH_ATTENTION,
     CONF_USE_GBNF_GRAMMAR,
     CONF_GBNF_GRAMMAR_FILE,
     CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,

@@ -67,6 +68,7 @@ from custom_components.llama_conversation.const import (
     DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
     DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
     DEFAULT_PROMPT_TEMPLATE,
+    DEFAULT_ENABLE_FLASH_ATTENTION,
     DEFAULT_USE_GBNF_GRAMMAR,
     DEFAULT_GBNF_GRAMMAR_FILE,
     DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,

@@ -304,7 +306,7 @@ def test_validate_options_schema():
     options_llama_hf = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_HF)
     assert set(options_llama_hf.keys()) == set(universal_options + [
         CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
-        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
+        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
         CONF_CONTEXT_LENGTH, # supports context length
         CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
         CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching

@@ -313,7 +315,7 @@ def test_validate_options_schema():
     options_llama_existing = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_EXISTING)
     assert set(options_llama_existing.keys()) == set(universal_options + [
         CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
-        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
+        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
         CONF_CONTEXT_LENGTH, # supports context length
         CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
         CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching

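The two schema hunks add CONF_ENABLE_FLASH_ATTENTION to the expected option set for both llama.cpp backend types. A condensed sketch of that style of schema test, with hypothetical constant values and defaults (the real local_llama_config_option_schema builds a much larger voluptuous schema), might be:

import voluptuous as vol

# Hypothetical string values for the constants referenced above.
CONF_BATCH_SIZE = "batch_size"
CONF_THREAD_COUNT = "n_threads"
CONF_BATCH_THREAD_COUNT = "n_batch_threads"
CONF_ENABLE_FLASH_ATTENTION = "enable_flash_attention"

def llama_cpp_options_schema() -> dict:
    # Stand-in for local_llama_config_option_schema(hass, backend_type):
    # a dict of voluptuous markers keyed by option name.
    return {
        vol.Required(CONF_BATCH_SIZE, default=512): int,
        vol.Required(CONF_THREAD_COUNT, default=8): int,
        vol.Required(CONF_BATCH_THREAD_COUNT, default=8): int,
        vol.Required(CONF_ENABLE_FLASH_ATTENTION, default=False): bool,
    }

def test_llama_cpp_schema_exposes_flash_attention():
    keys = {str(key) for key in llama_cpp_options_schema()}
    assert CONF_ENABLE_FLASH_ATTENTION in keys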