add min p and typical p samplers

Alex O'Connell
2024-04-10 23:55:01 -04:00
parent 1b22c06215
commit 7262a2057a
6 changed files with 88 additions and 8 deletions

View File

@@ -33,6 +33,8 @@ from .const import (
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_TYPICAL_P,
CONF_MIN_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
@@ -66,6 +68,8 @@ from .const import (
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
@@ -781,6 +785,8 @@ class LocalLLaMAAgent(LLaMAAgent):
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
min_p = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
_LOGGER.debug(f"Options: {self.entry.options}")
@@ -799,6 +805,8 @@ class LocalLLaMAAgent(LLaMAAgent):
temp=temperature,
top_k=top_k,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
grammar=self.grammar
)
@@ -953,6 +961,8 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
request_params["truncation_length"] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
request_params["top_k"] = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
request_params["min_p"] = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
return endpoint, request_params
@@ -966,6 +976,8 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
request_params["truncation_length"] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
request_params["top_k"] = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
request_params["min_p"] = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
return endpoint, request_params
@@ -1088,6 +1100,7 @@ class OllamaAPIAgent(LLaMAAgent):
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
@@ -1101,6 +1114,7 @@ class OllamaAPIAgent(LLaMAAgent):
"num_ctx": context_length,
"top_p": top_p,
"top_k": top_k,
"typical_p": typical_p,
"temperature": temperature,
"num_predict": max_tokens,
}
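For context, a minimal standalone sketch (not part of the commit) of how the two new options reach llama-cpp-python: Llama.generate accepts the same temp, top_k, top_p, min_p, typical_p keyword arguments used in the LocalLLaMAAgent hunks above, with min_p support depending on a recent llama-cpp-python release. The model path and prompt below are placeholders.

from llama_cpp import Llama

llm = Llama(model_path="/path/to/model.gguf")   # placeholder model path
tokens = llm.tokenize(b"prompt text")           # placeholder prompt

# generate() yields token ids one at a time; keyword names mirror the call in the diff
for token in llm.generate(
    tokens,
    temp=0.1,        # DEFAULT_TEMPERATURE
    top_k=40,        # DEFAULT_TOP_K
    top_p=1.0,       # DEFAULT_TOP_P
    min_p=0.0,       # DEFAULT_MIN_P: 0.0 keeps every token, i.e. the filter is off
    typical_p=1.0,   # DEFAULT_TYPICAL_P: 1.0 disables typical sampling
):
    if token == llm.token_eos():
        break
    print(llm.detokenize([token]).decode(errors="ignore"), end="")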

View File

@@ -41,6 +41,8 @@ from .const import (
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
@@ -79,6 +81,8 @@ from .const import (
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_BACKEND_TYPE,
DEFAULT_DOWNLOADED_MODEL_QUANTIZATION,
@@ -727,6 +731,16 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_MIN_P,
description={"suggested_value": options.get(CONF_MIN_P)},
default=DEFAULT_MIN_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_TYPICAL_P,
description={"suggested_value": options.get(CONF_TYPICAL_P)},
default=DEFAULT_TYPICAL_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_PROMPT_CACHING_ENABLED,
description={"suggested_value": options.get(CONF_PROMPT_CACHING_ENABLED)},
@@ -791,6 +805,16 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_MIN_P,
description={"suggested_value": options.get(CONF_MIN_P)},
default=DEFAULT_MIN_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_TYPICAL_P,
description={"suggested_value": options.get(CONF_TYPICAL_P)},
default=DEFAULT_TYPICAL_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_REQUEST_TIMEOUT,
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
@@ -899,6 +923,11 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_TYPICAL_P,
description={"suggested_value": options.get(CONF_TYPICAL_P)},
default=DEFAULT_TYPICAL_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_OLLAMA_JSON_MODE,
description={"suggested_value": options.get(CONF_OLLAMA_JSON_MODE)},

View File

@@ -33,7 +33,9 @@ DEFAULT_TOP_K = 40
CONF_TOP_P = "top_p"
DEFAULT_TOP_P = 1
CONF_TYPICAL_P = "typical_p"
DEFAULT_TYPICAL_P = 0.95
DEFAULT_TYPICAL_P = 1.0
CONF_MIN_P = "min_p"
DEFAULT_MIN_P = 0.0
CONF_TEMPERATURE = "temperature"
DEFAULT_TEMPERATURE = 0.1
CONF_REQUEST_TIMEOUT = "request_timeout"
@@ -156,6 +158,7 @@ DEFAULT_OPTIONS = types.MappingProxyType(
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
CONF_TOP_K: DEFAULT_TOP_K,
CONF_TOP_P: DEFAULT_TOP_P,
CONF_MIN_P: DEFAULT_MIN_P,
CONF_TYPICAL_P: DEFAULT_TYPICAL_P,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
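For background (standard sampler definitions, not taken from this commit): min-p discards tokens whose probability is below min_p times the most likely token's probability, so the default of 0.0 leaves the distribution untouched; typical-p keeps the smallest set of locally typical tokens whose cumulative probability reaches typical_p, so 1.0 disables it. A toy illustration of the min-p rule with made-up numbers:

# toy numbers, not from the integration: drop tokens with p < min_p * max(p)
probs = {"yes": 0.55, "no": 0.30, "maybe": 0.10, "banana": 0.05}
min_p = 0.2                                   # example value; the new default of 0.0 keeps everything
threshold = min_p * max(probs.values())       # 0.2 * 0.55 = 0.11
kept = {tok: p for tok, p in probs.items() if p >= threshold}
print(kept)                                   # {'yes': 0.55, 'no': 0.30}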

View File

@@ -54,6 +54,8 @@
"temperature": "Temperature",
"top_k": "Top K",
"top_p": "Top P",
"min_p": "Min P",
"typical_p": "Typical P",
"request_timeout": "Remote Request Timeout (seconds)",
"ollama_keep_alive": "Keep Alive/Inactivity Timeout (minutes)",
"ollama_json_mode": "JSON Output Mode",
@@ -102,6 +104,8 @@
"temperature": "Temperature",
"top_k": "Top K",
"top_p": "Top P",
"min_p": "Min P",
"typical_p": "Typical P",
"request_timeout": "Remote Request Timeout (seconds)",
"ollama_keep_alive": "Keep Alive/Inactivity Timeout (minutes)",
"ollama_json_mode": "JSON Output Mode",

View File

@@ -12,6 +12,8 @@ from custom_components.llama_conversation.const import (
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
@@ -46,6 +48,8 @@ from custom_components.llama_conversation.const import (
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
@@ -212,6 +216,8 @@ async def test_local_llama_agent(local_llama_agent_fixture):
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
min_p=local_llama_agent.entry.options[CONF_MIN_P],
grammar=ANY,
)
@@ -227,6 +233,8 @@ async def test_local_llama_agent(local_llama_agent_fixture):
local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
local_llama_agent.entry.options[CONF_TOP_K] = 20
local_llama_agent.entry.options[CONF_TOP_P] = 0.9
local_llama_agent.entry.options[CONF_MIN_P] = 0.2
local_llama_agent.entry.options[CONF_TYPICAL_P] = 0.95
local_llama_agent._update_options()
@@ -251,10 +259,10 @@ async def test_local_llama_agent(local_llama_agent_fixture):
temp=local_llama_agent.entry.options.get(CONF_TEMPERATURE),
top_k=local_llama_agent.entry.options.get(CONF_TOP_K),
top_p=local_llama_agent.entry.options.get(CONF_TOP_P),
typical_p=local_llama_agent.entry.options[CONF_TYPICAL_P],
min_p=local_llama_agent.entry.options[CONF_MIN_P],
grammar=ANY,
)
# TODO: test backends: text-gen-webui, ollama, generic openai
@pytest.fixture
def ollama_agent_fixture(config_entry, home_assistant_mock):
@@ -343,6 +351,7 @@ async def test_ollama_agent(ollama_agent_fixture):
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
"top_p": ollama_agent.entry.options[CONF_TOP_P],
"top_k": ollama_agent.entry.options[CONF_TOP_K],
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
},
@@ -365,6 +374,7 @@ async def test_ollama_agent(ollama_agent_fixture):
ollama_agent.entry.options[CONF_TEMPERATURE] = 2.0
ollama_agent.entry.options[CONF_TOP_K] = 20
ollama_agent.entry.options[CONF_TOP_P] = 0.9
ollama_agent.entry.options[CONF_TYPICAL_P] = 0.5
# do another turn of the same conversation
result = await ollama_agent.async_process(ConversationInput(
@@ -385,6 +395,7 @@ async def test_ollama_agent(ollama_agent_fixture):
"num_ctx": ollama_agent.entry.options[CONF_CONTEXT_LENGTH],
"top_p": ollama_agent.entry.options[CONF_TOP_P],
"top_k": ollama_agent.entry.options[CONF_TOP_K],
"typical_p": ollama_agent.entry.options[CONF_TYPICAL_P],
"temperature": ollama_agent.entry.options[CONF_TEMPERATURE],
"num_predict": ollama_agent.entry.options[CONF_MAX_TOKENS],
},
@@ -482,6 +493,8 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"prompt": ANY
@@ -510,6 +523,8 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"preset": "Some Preset",
@@ -525,6 +540,8 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
text_generation_webui_agent.entry.options[CONF_REMOTE_USE_CHAT_ENDPOINT] = True
text_generation_webui_agent.entry.options[CONF_TEMPERATURE] = 2.0
text_generation_webui_agent.entry.options[CONF_TOP_P] = 0.9
text_generation_webui_agent.entry.options[CONF_MIN_P] = 0.2
text_generation_webui_agent.entry.options[CONF_TYPICAL_P] = 0.95
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = ""
response_mock.json.return_value = {
@@ -572,6 +589,8 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
@@ -601,6 +620,8 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
@@ -630,6 +651,8 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],

View File

@@ -18,6 +18,8 @@ from custom_components.llama_conversation.const import (
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
@@ -58,6 +60,8 @@ from custom_components.llama_conversation.const import (
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
@@ -251,6 +255,7 @@ async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant
CONF_TOP_P: DEFAULT_TOP_P,
CONF_TOP_K: DEFAULT_TOP_K,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
CONF_TYPICAL_P: DEFAULT_TYPICAL_P,
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
@@ -285,6 +290,8 @@ async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant
assert result4["options"] == options_dict
mock_setup_entry.assert_called_once()
# TODO: write tests for configflow setup for llama.cpp (both versions) + text-generation-webui
def test_validate_options_schema():
universal_options = [
@@ -296,7 +303,7 @@ def test_validate_options_schema():
options_llama_hf = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_HF)
assert set(options_llama_hf.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
@@ -305,7 +312,7 @@ def test_validate_options_schema():
options_llama_existing = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_EXISTING)
assert set(options_llama_existing.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
@@ -314,7 +321,7 @@ def test_validate_options_schema():
options_ollama = local_llama_config_option_schema(None, BACKEND_TYPE_OLLAMA)
assert set(options_ollama.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_TYPICAL_P, # supports top_k, temperature, top_p and typical_p samplers
CONF_OLLAMA_KEEP_ALIVE_MIN, CONF_OLLAMA_JSON_MODE, # ollama specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
@@ -322,7 +329,7 @@ def test_validate_options_schema():
options_text_gen_webui = local_llama_config_option_schema(None, BACKEND_TYPE_TEXT_GEN_WEBUI)
assert set(options_text_gen_webui.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET, # text-gen-webui specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
@@ -336,7 +343,7 @@ def test_validate_options_schema():
options_llama_cpp_python_server = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER)
assert set(options_llama_cpp_python_server.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports all sampling parameters
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, # supports top_k, temperature, and top_p sampling
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_REQUEST_TIMEOUT, # is a remote backend
])