more tests, fix missing default options, and load ICL as utf8

Alex O'Connell
2024-04-08 20:58:07 -04:00
parent c436498c2e
commit 1577950137
6 changed files with 398 additions and 40 deletions


@@ -137,7 +137,7 @@ class LLaMAAgent(AbstractConversationAgent):
try:
icl_filename = os.path.join(os.path.dirname(__file__), filename)
-with open(icl_filename) as f:
+with open(icl_filename, encoding="utf-8-sig") as f:
self.in_context_examples = list(csv.DictReader(f))
if set(self.in_context_examples[0].keys()) != set(["service", "response" ]):
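Note: the encoding change matters because CSVs exported from spreadsheet tools often begin with a UTF-8 byte-order mark (BOM). "utf-8-sig" consumes a leading BOM, while plain "utf-8" leaves it attached to the first header name, which would break the service/response key check above. A minimal sketch of the difference, using a hypothetical file path:

    import csv

    # With plain utf-8 a leading BOM survives decoding, so the first
    # field name can come back as "\ufeffservice" instead of "service".
    with open("in_context_examples.csv", encoding="utf-8") as f:
        print(list(csv.DictReader(f))[0].keys())

    # utf-8-sig strips the BOM (if present) before decoding.
    with open("in_context_examples.csv", encoding="utf-8-sig") as f:
        print(list(csv.DictReader(f))[0].keys())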


@@ -162,16 +162,21 @@ DEFAULT_OPTIONS = types.MappingProxyType(
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
CONF_USE_GBNF_GRAMMAR: DEFAULT_USE_GBNF_GRAMMAR,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
+CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
+CONF_BATCH_SIZE: DEFAULT_BATCH_SIZE,
+CONF_THREAD_COUNT: DEFAULT_THREAD_COUNT,
+CONF_BATCH_THREAD_COUNT: DEFAULT_BATCH_THREAD_COUNT,
+CONF_PROMPT_CACHING_ENABLED: DEFAULT_PROMPT_CACHING_ENABLED,
}
)
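The five added entries (context length, batch size, thread counts, prompt caching) are the "missing default options" from the commit message: option lookups fall back to this mapping, so any key absent here silently resolves to None instead of a usable default. A minimal sketch of the failure mode, assuming a get-with-fallback lookup pattern and illustrative key names:

    import types

    DEFAULT_OPTIONS = types.MappingProxyType({"context_length": 2048})
    user_options = {}  # user never changed anything

    # Key present in the defaults: falls back cleanly.
    print(user_options.get("context_length", DEFAULT_OPTIONS.get("context_length")))  # 2048
    # Key missing from the defaults: silently becomes None.
    print(user_options.get("batch_size", DEFAULT_OPTIONS.get("batch_size")))  # None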

pytest.ini Normal file

@@ -0,0 +1,2 @@
+[pytest]
+asyncio_mode = auto
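With asyncio_mode = auto, pytest-asyncio treats every async test function as an asyncio test without per-function marking, which is what lets the @pytest.mark.asyncio decorator be dropped from the agent tests below. A minimal sketch (hypothetical test file):

    # test_smoke.py -- collected and run as-is under asyncio_mode = auto
    import asyncio

    async def test_event_loop_runs():
        assert await asyncio.sleep(0) is None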


@@ -12,4 +12,5 @@ homeassistant
hassil
home-assistant-intents
pytest
-pytest-asynciop
+pytest-asyncio
+pytest-homeassistant-custom-component


@@ -78,7 +78,6 @@ from custom_components.llama_conversation.const import (
DEFAULT_OPTIONS,
)
-import homeassistant.helpers.template
from homeassistant.components.conversation import ConversationInput
from homeassistant.const import (
CONF_HOST,
@@ -102,19 +101,6 @@ class MockConfigEntry:
self.entry_id = entry_id
self.data = WarnDict(data)
self.options = WarnDict(options)
-# @pytest.fixture
-# def patch_dependency_group1():
-# with patch('path.to.dependency1') as mock1, \
-# patch('path.to.dependency2') as mock2:
-# yield mock1, mock2
-# @pytest.fixture
-# def patch_dependency_group2():
-# with patch('path.to.dependency3') as mock3, \
-# patch('path.to.dependency4') as mock4:
-# yield mock3, mock4
@pytest.fixture
@@ -195,34 +181,77 @@ def local_llama_agent_fixture(config_entry, home_assistant_mock):
yield agent_obj, all_mocks
-@pytest.mark.asyncio # This decorator is necessary for pytest to run async test functions
+# TODO: test base llama agent (ICL loading other languages)
async def test_local_llama_agent(local_llama_agent_fixture):
local_llama_agent: LocalLLaMAAgent
all_mocks: dict[str, MagicMock]
local_llama_agent, all_mocks = local_llama_agent_fixture
# invoke the conversation agent
conversation_id = "test-conversation"
result = await local_llama_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
))
# assert on results + check side effects
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["llama_class"].assert_called_once_with(
-model_path=ANY,
-n_ctx=ANY,
-n_batch=ANY,
-n_threads=ANY,
-n_threads_batch=ANY,
+model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
+n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
+n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
+n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
+n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
)
all_mocks["tokenize"].assert_called_once()
all_mocks["generate"].assert_called_once_with(
ANY,
-temp=ANY,
-top_k=ANY,
-top_p=ANY,
+temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
+top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
+top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
grammar=ANY,
)
+# reset mock stats
+for mock in all_mocks.values():
+mock.reset_mock()
+# change options then apply them
+local_llama_agent.entry.options[CONF_CONTEXT_LENGTH] = 1024
+local_llama_agent.entry.options[CONF_BATCH_SIZE] = 1024
+local_llama_agent.entry.options[CONF_THREAD_COUNT] = 24
+local_llama_agent.entry.options[CONF_BATCH_THREAD_COUNT] = 24
+local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
+local_llama_agent.entry.options[CONF_TOP_K] = 20
+local_llama_agent.entry.options[CONF_TOP_P] = 0.9
+local_llama_agent._update_options()
+all_mocks["llama_class"].assert_called_once_with(
+model_path=local_llama_agent.entry.data.get(CONF_DOWNLOADED_MODEL_FILE),
+n_ctx=local_llama_agent.entry.options.get(CONF_CONTEXT_LENGTH),
+n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
+n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
+n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
+)
+# do another turn of the same conversation
+result = await local_llama_agent.async_process(ConversationInput(
+"turn off the office lights", MagicMock(), conversation_id, None, "en"
+))
+assert result.response.speech['plain']['speech'] == "I am saying something!"
+all_mocks["tokenize"].assert_called_once()
+all_mocks["generate"].assert_called_once_with(
+ANY,
+temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
+top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
+top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
+grammar=ANY,
+)
+# TODO: test backends: text-gen-webui, ollama, generic openai
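The assertion changes above tighten the mocks from unittest.mock.ANY placeholders to the exact values read off the config entry, so mis-plumbed options now fail the test instead of slipping past a wildcard. A minimal sketch of the two styles, with hypothetical values:

    from unittest.mock import ANY, MagicMock

    llama_class = MagicMock()
    llama_class(model_path="/models/model.gguf", n_ctx=2048)

    # Loose: passes no matter which values were used.
    llama_class.assert_called_once_with(model_path=ANY, n_ctx=ANY)
    # Strict: fails if the configured values were not passed through.
    llama_class.assert_called_once_with(model_path="/models/model.gguf", n_ctx=2048)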


@@ -1,21 +1,342 @@
import pytest
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
-# @pytest.fixture
-# def patch_dependency_group1():
-# with patch('path.to.dependency1') as mock1, \
-# patch('path.to.dependency2') as mock2:
-# yield mock1, mock2
from homeassistant import config_entries, setup
from homeassistant.core import HomeAssistant
from homeassistant.const import (
CONF_HOST,
CONF_PORT,
CONF_SSL,
)
from homeassistant.data_entry_flow import FlowResultType
-# @pytest.fixture
-# def patch_dependency_group2():
-# with patch('path.to.dependency3') as mock3, \
-# patch('path.to.dependency4') as mock4:
-# yield mock3, mock4
from custom_components.llama_conversation.config_flow import local_llama_config_option_schema, ConfigFlow
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_PROMPT_TEMPLATE,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
CONF_BATCH_SIZE,
CONF_THREAD_COUNT,
CONF_BATCH_THREAD_COUNT,
BACKEND_TYPE_LLAMA_HF,
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
DEFAULT_CHAT_MODEL,
DEFAULT_PROMPT,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_PROMPT_CACHING_ENABLED,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_BATCH_SIZE,
DEFAULT_THREAD_COUNT,
DEFAULT_BATCH_THREAD_COUNT,
DOMAIN,
)
# async def test_validate_config_flow_llama_hf(hass: HomeAssistant):
# result = await hass.config_entries.flow.async_init(
# DOMAIN, context={"source": config_entries.SOURCE_USER}
# )
# assert result["type"] == FlowResultType.FORM
# assert result["errors"] is None
def test_validate_setup_schemas():
pass
# result2 = await hass.config_entries.flow.async_configure(
# result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_LLAMA_HF },
# )
# assert result2["type"] == FlowResultType.FORM
# with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry:
# result3 = await hass.config_entries.flow.async_configure(
# result2["flow_id"],
# TEST_DATA,
# )
# await hass.async_block_till_done()
# assert result3["type"] == "create_entry"
# assert result3["title"] == ""
# assert result3["data"] == {
# # ACCOUNT_ID: TEST_DATA["account_id"],
# # CONF_PASSWORD: TEST_DATA["password"],
# # CONNECTION_TYPE: CLOUD,
# }
# assert result3["options"] == {}
# assert len(mock_setup_entry.mock_calls) == 1
@pytest.fixture
def validate_connections_mock():
validate_mock = MagicMock()
with patch.object(ConfigFlow, '_validate_text_generation_webui', new=validate_mock), \
patch.object(ConfigFlow, '_validate_ollama', new=validate_mock):
yield validate_mock
@pytest.fixture
def mock_setup_entry():
with patch("custom_components.llama_conversation.async_setup_entry", return_value=True) as mock_setup_entry, \
patch("custom_components.llama_conversation.async_unload_entry", return_value=True):
yield mock_setup_entry
async def test_validate_config_flow_generic_openai(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.FORM
assert result["errors"] == {}
assert result["step_id"] == "pick_backend"
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI },
)
assert result2["type"] == FlowResultType.FORM
assert result2["errors"] == {}
assert result2["step_id"] == "remote_model"
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
{
CONF_HOST: "localhost",
CONF_PORT: "5000",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
},
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {}
assert result3["step_id"] == "model_parameters"
options_dict = {
CONF_PROMPT: DEFAULT_PROMPT,
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
CONF_TOP_P: DEFAULT_TOP_P,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
}
result4 = await hass.config_entries.flow.async_configure(
result2["flow_id"], options_dict
)
await hass.async_block_till_done()
assert result4["type"] == "create_entry"
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
assert result4["data"] == {
CONF_BACKEND_TYPE: BACKEND_TYPE_GENERIC_OPENAI,
CONF_HOST: "localhost",
CONF_PORT: "5000",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
}
assert result4["options"] == options_dict
assert len(mock_setup_entry.mock_calls) == 1
async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant, enable_custom_integrations, validate_connections_mock):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.FORM
assert result["errors"] == {}
assert result["step_id"] == "pick_backend"
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], { CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA },
)
assert result2["type"] == FlowResultType.FORM
assert result2["errors"] == {}
assert result2["step_id"] == "remote_model"
# simulate incorrect settings on first try
validate_connections_mock.side_effect = [
("failed_to_connect", Exception("ConnectionError")),
(None, None)
]
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
{
CONF_HOST: "localhost",
CONF_PORT: "5000",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
},
)
assert result3["type"] == FlowResultType.FORM
assert len(result3["errors"]) == 1
assert "base" in result3["errors"]
assert result3["step_id"] == "remote_model"
# retry
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
{
CONF_HOST: "localhost",
CONF_PORT: "5001",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
},
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {}
assert result3["step_id"] == "model_parameters"
options_dict = {
CONF_PROMPT: DEFAULT_PROMPT,
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
CONF_TOP_P: DEFAULT_TOP_P,
CONF_TOP_K: DEFAULT_TOP_K,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
CONF_OLLAMA_KEEP_ALIVE_MIN: DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE: DEFAULT_OLLAMA_JSON_MODE,
}
result4 = await hass.config_entries.flow.async_configure(
result2["flow_id"], options_dict
)
await hass.async_block_till_done()
assert result4["type"] == "create_entry"
assert result4["title"] == f"LLM Model '{DEFAULT_CHAT_MODEL}' (remote)"
assert result4["data"] == {
CONF_BACKEND_TYPE: BACKEND_TYPE_OLLAMA,
CONF_HOST: "localhost",
CONF_PORT: "5001",
CONF_SSL: False,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
}
assert result4["options"] == options_dict
mock_setup_entry.assert_called_once()
def test_validate_options_schema():
-pass
universal_options = [
CONF_PROMPT, CONF_PROMPT_TEMPLATE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_SERVICE_CALL_REGEX, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS,
]
options_llama_hf = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_HF)
assert set(options_llama_hf.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
])
options_llama_existing = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_EXISTING)
assert set(options_llama_existing.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
])
options_ollama = local_llama_config_option_schema(None, BACKEND_TYPE_OLLAMA)
assert set(options_ollama.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE, # ollama specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_text_gen_webui = local_llama_config_option_schema(None, BACKEND_TYPE_TEXT_GEN_WEBUI)
assert set(options_text_gen_webui.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, # text-gen-webui specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_generic_openai = local_llama_config_option_schema(None, BACKEND_TYPE_GENERIC_OPENAI)
assert set(options_generic_openai.keys()) == set(universal_options + [
CONF_TEMPERATURE, CONF_TOP_P, # only supports top_p and temperature sampling
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
options_llama_cpp_python_server = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER)
assert set(options_llama_cpp_python_server.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])
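The Ollama retry test works because MagicMock.side_effect accepts an iterable: each call to the mock consumes the next element, so the first validation attempt reports a connection failure and the retry succeeds. A minimal sketch of that behaviour, with hypothetical names:

    from unittest.mock import MagicMock

    validate = MagicMock(side_effect=[
        ("failed_to_connect", Exception("ConnectionError")),
        (None, None),
    ])

    assert validate()[0] == "failed_to_connect"  # first call: simulated failure
    assert validate() == (None, None)            # second call: success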