diff --git a/TODO.md b/TODO.md
index 596ef20..dd1e0da 100644
--- a/TODO.md
+++ b/TODO.md
@@ -45,12 +45,11 @@
 
 ## v0.4
 TODO for release:
-- [ ] re-order the settings on the options config flow page. the order is very confusing
+- [x] re-order the settings on the options config flow page. the order is very confusing
 - [ ] split out entity functionality so we can support conversation + ai tasks
 - [x] fix icl examples to match new tool calling syntax config
 - [x] set up docker-compose for running all of the various backends
 - [ ] fix and re-upload all compatible old models (+ upload all original safetensors)
-- [ ] dedicated localai backend (tailored openai variant /w model loading)
 - [x] fix the openai responses backend
 
 ## more complicated ideas
diff --git a/custom_components/llama_conversation/backends/llamacpp.py b/custom_components/llama_conversation/backends/llamacpp.py
index b463ff3..37c8687 100644
--- a/custom_components/llama_conversation/backends/llamacpp.py
+++ b/custom_components/llama_conversation/backends/llamacpp.py
@@ -6,7 +6,7 @@ import logging
 import os
 import threading
 import time
-from typing import Any, Callable, List, Generator, AsyncGenerator, Optional
+from typing import Any, Callable, List, Generator, AsyncGenerator, Optional, cast
 
 from homeassistant.components import conversation as conversation
 from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
@@ -28,15 +28,15 @@ from custom_components.llama_conversation.const import (
     CONF_TYPICAL_P,
     CONF_MIN_P,
     CONF_DOWNLOADED_MODEL_FILE,
-    CONF_ENABLE_FLASH_ATTENTION,
+    CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
     CONF_USE_GBNF_GRAMMAR,
     CONF_GBNF_GRAMMAR_FILE,
     CONF_PROMPT_CACHING_ENABLED,
     CONF_PROMPT_CACHING_INTERVAL,
     CONF_CONTEXT_LENGTH,
-    CONF_BATCH_SIZE,
-    CONF_THREAD_COUNT,
-    CONF_BATCH_THREAD_COUNT,
+    CONF_LLAMACPP_BATCH_SIZE,
+    CONF_LLAMACPP_THREAD_COUNT,
+    CONF_LLAMACPP_BATCH_THREAD_COUNT,
     DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
     DEFAULT_TEMPERATURE,
@@ -44,15 +44,16 @@ from custom_components.llama_conversation.const import (
     DEFAULT_TOP_P,
     DEFAULT_MIN_P,
     DEFAULT_TYPICAL_P,
-    DEFAULT_ENABLE_FLASH_ATTENTION,
+    DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
     DEFAULT_USE_GBNF_GRAMMAR,
     DEFAULT_GBNF_GRAMMAR_FILE,
     DEFAULT_PROMPT_CACHING_ENABLED,
     DEFAULT_PROMPT_CACHING_INTERVAL,
     DEFAULT_CONTEXT_LENGTH,
-    DEFAULT_BATCH_SIZE,
-    DEFAULT_THREAD_COUNT,
-    DEFAULT_BATCH_THREAD_COUNT,
+    DEFAULT_LLAMACPP_BATCH_SIZE,
+    DEFAULT_LLAMACPP_THREAD_COUNT,
+    DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
+    DOMAIN,
 )
 
 from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
@@ -71,7 +72,7 @@ class LlamaCppAgent(LocalLLMAgent):
     llm: LlamaType
     grammar: Any
     llama_cpp_module: Any
-    remove_prompt_caching_listener: Callable
+    remove_prompt_caching_listener: Optional[Callable]
     model_lock: threading.Lock
     last_cache_prime: float
     last_updated_entities: dict[str, float]
@@ -81,7 +82,7 @@ class LlamaCppAgent(LocalLLMAgent):
     _attr_supports_streaming = True
 
     def _load_model(self, entry: ConfigEntry) -> None:
-        self.model_path = entry.data.get(CONF_DOWNLOADED_MODEL_FILE)
+        self.model_path = entry.data.get(CONF_DOWNLOADED_MODEL_FILE, "")
 
         _LOGGER.info(
             "Using model file '%s'", self.model_path
@@ -109,18 +110,18 @@ class LlamaCppAgent(LocalLLMAgent):
         _LOGGER.debug(f"Loading model '{self.model_path}'...")
         self.loaded_model_settings = {}
         self.loaded_model_settings[CONF_CONTEXT_LENGTH] = entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
-        self.loaded_model_settings[CONF_BATCH_SIZE] = entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
-        self.loaded_model_settings[CONF_THREAD_COUNT] = entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
-        self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
-        self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)
+        self.loaded_model_settings[CONF_LLAMACPP_BATCH_SIZE] = entry.options.get(CONF_LLAMACPP_BATCH_SIZE, DEFAULT_LLAMACPP_BATCH_SIZE)
+        self.loaded_model_settings[CONF_LLAMACPP_THREAD_COUNT] = entry.options.get(CONF_LLAMACPP_THREAD_COUNT, DEFAULT_LLAMACPP_THREAD_COUNT)
+        self.loaded_model_settings[CONF_LLAMACPP_BATCH_THREAD_COUNT] = entry.options.get(CONF_LLAMACPP_BATCH_THREAD_COUNT, DEFAULT_LLAMACPP_BATCH_THREAD_COUNT)
+        self.loaded_model_settings[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] = entry.options.get(CONF_LLAMACPP_ENABLE_FLASH_ATTENTION, DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION)
 
         self.llm = Llama(
             model_path=self.model_path,
             n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
-            n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
-            n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
-            n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
-            flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
+            n_batch=int(self.loaded_model_settings[CONF_LLAMACPP_BATCH_SIZE]),
+            n_threads=int(self.loaded_model_settings[CONF_LLAMACPP_THREAD_COUNT]),
+            n_threads_batch=int(self.loaded_model_settings[CONF_LLAMACPP_BATCH_THREAD_COUNT]),
+            flash_attn=self.loaded_model_settings[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION],
         )
 
         _LOGGER.debug("Model loaded")
@@ -136,7 +137,7 @@ class LlamaCppAgent(LocalLLMAgent):
         # ))
 
         self.remove_prompt_caching_listener = None
-        self.last_cache_prime = None
+        self.last_cache_prime = 0.0
         self.last_updated_entities = {}
         self.cache_refresh_after_cooldown = False
         self.model_lock = threading.Lock()
@@ -167,26 +168,26 @@ class LlamaCppAgent(LocalLLMAgent):
         model_reloaded = False
 
         if self.loaded_model_settings[CONF_CONTEXT_LENGTH] != self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) or \
-            self.loaded_model_settings[CONF_BATCH_SIZE] != self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE) or \
-            self.loaded_model_settings[CONF_THREAD_COUNT] != self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT) or \
-            self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] != self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT) or \
-            self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] != self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION):
+            self.loaded_model_settings[CONF_LLAMACPP_BATCH_SIZE] != self.entry.options.get(CONF_LLAMACPP_BATCH_SIZE, DEFAULT_LLAMACPP_BATCH_SIZE) or \
+            self.loaded_model_settings[CONF_LLAMACPP_THREAD_COUNT] != self.entry.options.get(CONF_LLAMACPP_THREAD_COUNT, DEFAULT_LLAMACPP_THREAD_COUNT) or \
+            self.loaded_model_settings[CONF_LLAMACPP_BATCH_THREAD_COUNT] != self.entry.options.get(CONF_LLAMACPP_BATCH_THREAD_COUNT, DEFAULT_LLAMACPP_BATCH_THREAD_COUNT) or \
+            self.loaded_model_settings[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] != self.entry.options.get(CONF_LLAMACPP_ENABLE_FLASH_ATTENTION, DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION):
 
             _LOGGER.debug(f"Reloading model '{self.model_path}'...")
             self.loaded_model_settings[CONF_CONTEXT_LENGTH] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
-            self.loaded_model_settings[CONF_BATCH_SIZE] = self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
-            self.loaded_model_settings[CONF_THREAD_COUNT] = self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
-            self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
-            self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)
+            self.loaded_model_settings[CONF_LLAMACPP_BATCH_SIZE] = self.entry.options.get(CONF_LLAMACPP_BATCH_SIZE, DEFAULT_LLAMACPP_BATCH_SIZE)
+            self.loaded_model_settings[CONF_LLAMACPP_THREAD_COUNT] = self.entry.options.get(CONF_LLAMACPP_THREAD_COUNT, DEFAULT_LLAMACPP_THREAD_COUNT)
+            self.loaded_model_settings[CONF_LLAMACPP_BATCH_THREAD_COUNT] = self.entry.options.get(CONF_LLAMACPP_BATCH_THREAD_COUNT, DEFAULT_LLAMACPP_BATCH_THREAD_COUNT)
+            self.loaded_model_settings[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION] = self.entry.options.get(CONF_LLAMACPP_ENABLE_FLASH_ATTENTION, DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION)
 
             Llama = getattr(self.llama_cpp_module, "Llama")
             self.llm = Llama(
                 model_path=self.model_path,
                 n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
-                n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
-                n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
-                n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
-                flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
+                n_batch=int(self.loaded_model_settings[CONF_LLAMACPP_BATCH_SIZE]),
+                n_threads=int(self.loaded_model_settings[CONF_LLAMACPP_THREAD_COUNT]),
+                n_threads_batch=int(self.loaded_model_settings[CONF_LLAMACPP_BATCH_THREAD_COUNT]),
+                flash_attn=self.loaded_model_settings[CONF_LLAMACPP_ENABLE_FLASH_ATTENTION],
             )
             _LOGGER.debug("Model loaded")
             model_reloaded = True
@@ -211,7 +212,7 @@ class LlamaCppAgent(LocalLLMAgent):
         else:
             self._set_prompt_caching(enabled=False)
 
-    def _async_get_exposed_entities(self) -> dict[str, str]:
+    def _async_get_exposed_entities(self) -> dict[str, dict[str, str]]:
         """Takes the super class function results and sorts the entities with the recently updated at the end"""
 
         entities = LocalLLMAgent._async_get_exposed_entities(self)
@@ -219,7 +220,7 @@ class LlamaCppAgent(LocalLLMAgent):
         if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
             return entities
 
-        entity_order = { name: None for name in entities.keys() }
+        entity_order: dict[str, Optional[float]] = { name: None for name in entities.keys() }
        entity_order.update(self.last_updated_entities)
 
         def sort_key(item):
@@ -235,7 +236,7 @@ class LlamaCppAgent(LocalLLMAgent):
 
         _LOGGER.debug(f"sorted_items: {sorted_items}")
 
-        sorted_entities = {}
+        sorted_entities: dict[str, dict[str, str]] = {}
         for item_name, _ in sorted_items:
             sorted_entities[item_name] = entities[item_name]
 
@@ -271,6 +272,7 @@ class LlamaCppAgent(LocalLLMAgent):
         try:
             llm_api = await llm.async_get_api(
                 self.hass, self.entry.options[CONF_LLM_HASS_API],
+                llm_context=llm.LLMContext(DOMAIN, context=None, language=None, assistant=None, device_id=None)
             )
         except HomeAssistantError:
             _LOGGER.exception("Failed to get LLM API when caching prompt!")
@@ -301,33 +303,48 @@ class LlamaCppAgent(LocalLLMAgent):
                 return
 
         try:
+            # Build system/user messages and use the chat-completion API to prime
+            # the model. We request only a single token (max_tokens=1) and
+            # discard the result. This avoids implementing any streaming logic
+            # while still priming the KV cache with the system prompt.
             raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
-            prompt = self._format_prompt([
-                { "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)},
-                { "role": "user", "message": "" }
-            ], include_generation_prompt=False)
+            system_prompt = self._generate_system_prompt(raw_prompt, llm_api)
 
-            input_tokens = self.llm.tokenize(
-                prompt.encode(), add_bos=False
-            )
+            messages = get_oai_formatted_messages([
+                conversation.SystemContent(content=system_prompt),
+                conversation.UserContent(content="")
+            ])
+            tools = None
+            if llm_api:
+                tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
 
             temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
             top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
             top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
+            min_p = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
+            typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
             grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None
 
-            _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
+            _LOGGER.debug("Priming model cache via chat completion API...")
 
-            # grab just one token. should prime the kv cache with the system prompt
-            next(self.llm.generate(
-                input_tokens,
-                temp=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                grammar=grammar
-            ))
+            try:
+                # avoid strict typing issues from the llama-cpp-python bindings
+                self.llm.create_chat_completion(
+                    messages,
+                    tools=tools,
+                    temperature=temperature,
+                    top_k=top_k,
+                    top_p=top_p,
+                    min_p=min_p,
+                    typical_p=typical_p,
+                    max_tokens=1,
+                    grammar=grammar,
+                    stream=False,
+                )
 
-            self.last_cache_prime = time.time()
+                self.last_cache_prime = time.time()
+            except Exception:
+                _LOGGER.exception("Failed to prime model cache")
         finally:
             self.model_lock.release()
 
diff --git a/custom_components/llama_conversation/backends/tailored_openai.py b/custom_components/llama_conversation/backends/tailored_openai.py
index c7517b2..6372d81 100644
--- a/custom_components/llama_conversation/backends/tailored_openai.py
+++ b/custom_components/llama_conversation/backends/tailored_openai.py
@@ -121,4 +121,4 @@ class LlamaCppServerAgent(GenericOpenAIAPIAgent):
         if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
             request_params["grammar"] = self.grammar
 
-        return endpoint, request_params
\ No newline at end of file
+        return endpoint, request_params
diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py
index 2548783..e4b0c1c 100644
--- a/custom_components/llama_conversation/config_flow.py
+++ b/custom_components/llama_conversation/config_flow.py
@@ -57,7 +57,7 @@ from .const import (
     CONF_TOOL_CALL_PREFIX,
     CONF_TOOL_CALL_SUFFIX,
     CONF_ENABLE_LEGACY_TOOL_CALLING,
-    CONF_ENABLE_FLASH_ATTENTION,
+    CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
     CONF_USE_GBNF_GRAMMAR,
     CONF_GBNF_GRAMMAR_FILE,
     CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
@@ -80,9 +80,9 @@ from .const import (
     CONF_GENERIC_OPENAI_PATH,
     CONF_GENERIC_OPENAI_VALIDATE_MODEL,
     CONF_CONTEXT_LENGTH,
-    CONF_BATCH_SIZE,
-    CONF_THREAD_COUNT,
-    CONF_BATCH_THREAD_COUNT,
+    CONF_LLAMACPP_BATCH_SIZE,
+    CONF_LLAMACPP_THREAD_COUNT,
+    CONF_LLAMACPP_BATCH_THREAD_COUNT,
     DEFAULT_CHAT_MODEL,
     DEFAULT_PORT,
     DEFAULT_SSL,
@@ -107,7 +107,7 @@ from .const import (
     DEFAULT_TOOL_CALL_PREFIX,
     DEFAULT_TOOL_CALL_SUFFIX,
     DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
-    DEFAULT_ENABLE_FLASH_ATTENTION,
+    DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
     DEFAULT_USE_GBNF_GRAMMAR,
     DEFAULT_GBNF_GRAMMAR_FILE,
     DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
@@ -126,9 +126,9 @@ from .const import (
     DEFAULT_GENERIC_OPENAI_PATH,
     DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL,
     DEFAULT_CONTEXT_LENGTH,
-    DEFAULT_BATCH_SIZE,
-    DEFAULT_THREAD_COUNT,
-    DEFAULT_BATCH_THREAD_COUNT,
+    DEFAULT_LLAMACPP_BATCH_SIZE,
+    DEFAULT_LLAMACPP_THREAD_COUNT,
+    DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
     BACKEND_TYPE_LLAMA_HF,
     BACKEND_TYPE_LLAMA_EXISTING,
     BACKEND_TYPE_TEXT_GEN_WEBUI,
@@ -882,7 +882,7 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
             CONF_MAX_TOKENS,
             description={"suggested_value": options.get(CONF_MAX_TOKENS)},
             default=DEFAULT_MAX_TOKENS,
-        ): int,
+        ): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
         vol.Required(
             CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
             description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
@@ -926,7 +926,7 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
     }
 
     if is_local_backend(backend_type):
-        result = insert_after_key(result, CONF_MAX_TOKENS, {
+        result.update({
             vol.Required(
                 CONF_TOP_K,
                 description={"suggested_value": options.get(CONF_TOP_K)},
@@ -969,24 +969,24 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
                 default=DEFAULT_CONTEXT_LENGTH,
             ): NumberSelector(NumberSelectorConfig(min=512, max=32768, step=1)),
             vol.Required(
-                CONF_BATCH_SIZE,
-                description={"suggested_value": options.get(CONF_BATCH_SIZE)},
-                default=DEFAULT_BATCH_SIZE,
+                CONF_LLAMACPP_BATCH_SIZE,
+                description={"suggested_value": options.get(CONF_LLAMACPP_BATCH_SIZE)},
+                default=DEFAULT_LLAMACPP_BATCH_SIZE,
             ): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
             vol.Required(
-                CONF_THREAD_COUNT,
-                description={"suggested_value": options.get(CONF_THREAD_COUNT)},
-                default=DEFAULT_THREAD_COUNT,
+                CONF_LLAMACPP_THREAD_COUNT,
+                description={"suggested_value": options.get(CONF_LLAMACPP_THREAD_COUNT)},
+                default=DEFAULT_LLAMACPP_THREAD_COUNT,
             ): NumberSelector(NumberSelectorConfig(min=1, max=(os.cpu_count() * 2), step=1)),
             vol.Required(
-                CONF_BATCH_THREAD_COUNT,
-                description={"suggested_value": options.get(CONF_BATCH_THREAD_COUNT)},
-                default=DEFAULT_BATCH_THREAD_COUNT,
+                CONF_LLAMACPP_BATCH_THREAD_COUNT,
+                description={"suggested_value": options.get(CONF_LLAMACPP_BATCH_THREAD_COUNT)},
+                default=DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
             ): NumberSelector(NumberSelectorConfig(min=1, max=(os.cpu_count() * 2), step=1)),
             vol.Required(
-                CONF_ENABLE_FLASH_ATTENTION,
-                description={"suggested_value": options.get(CONF_ENABLE_FLASH_ATTENTION)},
-                default=DEFAULT_ENABLE_FLASH_ATTENTION,
+                CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
+                description={"suggested_value": options.get(CONF_LLAMACPP_ENABLE_FLASH_ATTENTION)},
+                default=DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
             ): BooleanSelector(BooleanSelectorConfig()),
             vol.Required(
                 CONF_USE_GBNF_GRAMMAR,
@@ -1000,7 +1000,7 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
             ): str
         })
     elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
-        result = insert_after_key(result, CONF_MAX_TOKENS, {
+        result.update({
             vol.Required(
                 CONF_CONTEXT_LENGTH,
                 description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
@@ -1052,7 +1052,7 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
             )),
         })
     elif backend_type in BACKEND_TYPE_GENERIC_OPENAI:
-        result = insert_after_key(result, CONF_MAX_TOKENS, {
+        result.update({
             vol.Required(
                 CONF_TEMPERATURE,
                 description={"suggested_value": options.get(CONF_TEMPERATURE)},
@@ -1076,7 +1076,7 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
         })
     elif backend_type in BACKEND_TYPE_GENERIC_OPENAI_RESPONSES:
         del result[CONF_REMEMBER_NUM_INTERACTIONS]
-        result = insert_after_key(result, CONF_REMEMBER_CONVERSATION, {
+        result.update({
             vol.Required(
                 CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
                 description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES)},
@@ -1101,7 +1101,7 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
             ): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
         })
     elif backend_type == BACKEND_TYPE_LLAMA_CPP_SERVER:
-        result = insert_after_key(result, CONF_MAX_TOKENS, {
+        result.update({
             vol.Required(
                 CONF_TOP_K,
                 description={"suggested_value": options.get(CONF_TOP_K)},
@@ -1139,7 +1139,7 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
             ): bool,
         })
     elif backend_type == BACKEND_TYPE_OLLAMA:
-        result = insert_after_key(result, CONF_MAX_TOKENS, {
+        result.update({
             vol.Required(
                 CONF_CONTEXT_LENGTH,
                 description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
@@ -1182,4 +1182,53 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
             ): NumberSelector(NumberSelectorConfig(min=-1, max=1440, step=1, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)),
         })
 
+    # sort the options
+    global_order = [
+        # general
+        CONF_LLM_HASS_API,
+        CONF_PROMPT,
+        CONF_CONTEXT_LENGTH,
+        CONF_MAX_TOKENS,
+        CONF_OPENAI_API_KEY,
+        CONF_REQUEST_TIMEOUT,
+        # sampling parameters
+        CONF_TEMPERATURE,
+        CONF_TOP_K,
+        CONF_TOP_P,
+        CONF_MIN_P,
+        CONF_TYPICAL_P,
+        # tool calling/reasoning
+        CONF_THINKING_PREFIX,
+        CONF_THINKING_SUFFIX,
+        CONF_TOOL_CALL_PREFIX,
+        CONF_TOOL_CALL_SUFFIX,
+        CONF_MAX_TOOL_CALL_ITERATIONS,
+        CONF_ENABLE_LEGACY_TOOL_CALLING,
+        CONF_USE_GBNF_GRAMMAR,
+        CONF_GBNF_GRAMMAR_FILE,
+        # integration specific options
+        CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
+        CONF_REFRESH_SYSTEM_PROMPT,
+        CONF_REMEMBER_CONVERSATION,
+        CONF_REMEMBER_NUM_INTERACTIONS,
+        CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
+        CONF_PROMPT_CACHING_ENABLED,
+        CONF_PROMPT_CACHING_INTERVAL,
+        CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
+        CONF_IN_CONTEXT_EXAMPLES_FILE,
+        CONF_NUM_IN_CONTEXT_EXAMPLES,
+        # backend specific options
+        CONF_LLAMACPP_BATCH_SIZE,
+        CONF_LLAMACPP_THREAD_COUNT,
+        CONF_LLAMACPP_BATCH_THREAD_COUNT,
+        CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
+        CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
+        CONF_TEXT_GEN_WEBUI_PRESET,
+        CONF_TEXT_GEN_WEBUI_CHAT_MODE,
+        CONF_OLLAMA_KEEP_ALIVE_MIN,
+        CONF_OLLAMA_JSON_MODE,
+    ]
+
+    result = { k: v for k, v in sorted(result.items(), key=lambda item: global_order.index(item[0]) if item[0] in global_order else 9999) }
+
     return result
diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py
index f00879e..87fecba 100644
--- a/custom_components/llama_conversation/const.py
+++ b/custom_components/llama_conversation/const.py
@@ -90,7 +90,7 @@ CONF_CHAT_MODEL = "huggingface_model"
 DEFAULT_CHAT_MODEL = "acon96/Home-3B-v3-GGUF"
 RECOMMENDED_CHAT_MODELS = [ "acon96/Home-3B-v3-GGUF", "acon96/Home-1B-v3-GGUF", "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" ]
 CONF_MAX_TOKENS = "max_new_tokens"
-DEFAULT_MAX_TOKENS = 128
+DEFAULT_MAX_TOKENS = 512
 CONF_TOP_K = "top_k"
 DEFAULT_TOP_K = 40
 CONF_TOP_P = "top_p"
@@ -139,8 +139,8 @@ CONF_TOOL_CALL_SUFFIX = "tool_call_suffix"
 DEFAULT_TOOL_CALL_SUFFIX = ""
 CONF_ENABLE_LEGACY_TOOL_CALLING = "enable_legacy_tool_calling"
 DEFAULT_ENABLE_LEGACY_TOOL_CALLING = False
-CONF_ENABLE_FLASH_ATTENTION = "enable_flash_attention"
-DEFAULT_ENABLE_FLASH_ATTENTION = False
+CONF_LLAMACPP_ENABLE_FLASH_ATTENTION = "enable_flash_attention"
+DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION = False
 CONF_USE_GBNF_GRAMMAR = "gbnf_grammar"
 DEFAULT_USE_GBNF_GRAMMAR = False
 CONF_GBNF_GRAMMAR_FILE = "gbnf_grammar_file"
@@ -183,12 +183,12 @@ CONF_GENERIC_OPENAI_VALIDATE_MODEL = "openai_validate_model"
 DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL = True
 CONF_CONTEXT_LENGTH = "context_length"
 DEFAULT_CONTEXT_LENGTH = 2048
-CONF_BATCH_SIZE = "batch_size"
-DEFAULT_BATCH_SIZE = 512
-CONF_THREAD_COUNT = "n_threads"
-DEFAULT_THREAD_COUNT = os.cpu_count()
-CONF_BATCH_THREAD_COUNT = "n_batch_threads"
-DEFAULT_BATCH_THREAD_COUNT = os.cpu_count()
+CONF_LLAMACPP_BATCH_SIZE = "batch_size"
+DEFAULT_LLAMACPP_BATCH_SIZE = 512
+CONF_LLAMACPP_THREAD_COUNT = "n_threads"
+DEFAULT_LLAMACPP_THREAD_COUNT = os.cpu_count()
+CONF_LLAMACPP_BATCH_THREAD_COUNT = "n_batch_threads"
+DEFAULT_LLAMACPP_BATCH_THREAD_COUNT = os.cpu_count()
 
 DEFAULT_OPTIONS = types.MappingProxyType(
     {
@@ -200,7 +200,7 @@ DEFAULT_OPTIONS = types.MappingProxyType(
         CONF_TYPICAL_P: DEFAULT_TYPICAL_P,
         CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
         CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
-        CONF_ENABLE_FLASH_ATTENTION: DEFAULT_ENABLE_FLASH_ATTENTION,
+        CONF_LLAMACPP_ENABLE_FLASH_ATTENTION: DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
         CONF_USE_GBNF_GRAMMAR: DEFAULT_USE_GBNF_GRAMMAR,
         CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
         CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
@@ -210,9 +210,9 @@ DEFAULT_OPTIONS = types.MappingProxyType(
         CONF_IN_CONTEXT_EXAMPLES_FILE: DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
         CONF_NUM_IN_CONTEXT_EXAMPLES: DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
         CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
-        CONF_BATCH_SIZE: DEFAULT_BATCH_SIZE,
-        CONF_THREAD_COUNT: DEFAULT_THREAD_COUNT,
-        CONF_BATCH_THREAD_COUNT: DEFAULT_BATCH_THREAD_COUNT,
+        CONF_LLAMACPP_BATCH_SIZE: DEFAULT_LLAMACPP_BATCH_SIZE,
+        CONF_LLAMACPP_THREAD_COUNT: DEFAULT_LLAMACPP_THREAD_COUNT,
+        CONF_LLAMACPP_BATCH_THREAD_COUNT: DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
         CONF_PROMPT_CACHING_ENABLED: DEFAULT_PROMPT_CACHING_ENABLED,
         CONF_OLLAMA_KEEP_ALIVE_MIN: DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
         CONF_OLLAMA_JSON_MODE: DEFAULT_OLLAMA_JSON_MODE,
diff --git a/custom_components/llama_conversation/conversation.py b/custom_components/llama_conversation/conversation.py
index a43789d..def0896 100644
--- a/custom_components/llama_conversation/conversation.py
+++ b/custom_components/llama_conversation/conversation.py
@@ -466,7 +466,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
 
         return list(domains)
 
-    def _async_get_exposed_entities(self) -> dict[str, dict]:
+    def _async_get_exposed_entities(self) -> dict[str, dict[str, Any]]:
         """Gather exposed entity states"""
         entity_states: dict[str, dict] = {}
         entity_registry = er.async_get(self.hass)
diff --git a/custom_components/llama_conversation/translations/en.json b/custom_components/llama_conversation/translations/en.json
index 7fff800..d9c53fb 100644
--- a/custom_components/llama_conversation/translations/en.json
+++ b/custom_components/llama_conversation/translations/en.json
@@ -40,7 +40,6 @@
         "openai_validate_model": "Validate model exists?",
"text_generation_webui_admin_key": "Admin Key", "text_generation_webui_preset": "Generation Preset/Character Name", - "remote_use_chat_endpoint": "Use chat completions endpoint", "text_generation_webui_chat_mode": "Chat Mode", "selected_language": "Model Language" }, @@ -85,7 +84,6 @@ "in_context_examples_file": "In context learning examples CSV filename", "num_in_context_examples": "Number of ICL examples to generate", "text_generation_webui_preset": "Generation Preset/Character Name", - "remote_use_chat_endpoint": "Use chat completions endpoint", "text_generation_webui_chat_mode": "Chat Mode", "prompt_caching": "Enable Prompt Caching", "prompt_caching_interval": "Prompt Caching fastest refresh interval (sec)", @@ -104,7 +102,6 @@ "llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-3)'", "prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.", "in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this", - "remote_use_chat_endpoint": "If this is enabled, then the integration will use the chat completion HTTP endpoint instead of the text completion one.", "extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.", "gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.", "prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below." 
@@ -127,10 +124,10 @@
         "min_p": "Min P",
         "typical_p": "Typical P",
         "request_timeout": "Remote Request Timeout (seconds)",
-        "ollama_keep_alive": "Keep Alive/Inactivity Timeout (minutes)",
-        "ollama_json_mode": "JSON Output Mode",
+        "ollama_keep_alive": "(ollama) Keep Alive/Inactivity Timeout (minutes)",
+        "ollama_json_mode": "(ollama) JSON Output Mode",
         "extra_attributes_to_expose": "Additional attribute to expose in the context",
-        "enable_flash_attention": "Enable Flash Attention",
+        "enable_flash_attention": "(llama.cpp) Enable Flash Attention",
         "gbnf_grammar": "Enable GBNF Grammar",
         "gbnf_grammar_file": "GBNF Grammar Filename",
         "openai_api_key": "API Key",
@@ -142,15 +139,14 @@
         "in_context_examples": "Enable in context learning (ICL) examples",
         "in_context_examples_file": "In context learning examples CSV filename",
         "num_in_context_examples": "Number of ICL examples to generate",
-        "text_generation_webui_preset": "Generation Preset/Character Name",
-        "remote_use_chat_endpoint": "Use chat completions endpoint",
-        "text_generation_webui_chat_mode": "Chat Mode",
+        "text_generation_webui_preset": "(text-generation-webui) Generation Preset/Character Name",
+        "text_generation_webui_chat_mode": "(text-generation-webui) Chat Mode",
         "prompt_caching": "Enable Prompt Caching",
         "prompt_caching_interval": "Prompt Caching fastest refresh interval (sec)",
         "context_length": "Context Length",
-        "batch_size": "Batch Size",
-        "n_threads": "Thread Count",
-        "n_batch_threads": "Batch Thread Count",
+        "batch_size": "(llama.cpp) Batch Size",
+        "n_threads": "(llama.cpp) Thread Count",
+        "n_batch_threads": "(llama.cpp) Batch Thread Count",
         "thinking_prefix": "Reasoning Content Prefix",
         "thinking_suffix": "Reasoning Content Suffix",
         "tool_call_prefix": "Tool Call Prefix",
@@ -162,7 +158,6 @@
         "llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-3)'",
         "prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
         "in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
-        "remote_use_chat_endpoint": "If this is enabled, then the integration will use the chat completion HTTP endpoint instead of the text completion one.",
         "extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
         "gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.",
         "prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below."
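
The settings re-ordering checked off in TODO.md comes down to the global_order sort added at the end of local_llama_config_option_schema: keys named in global_order keep that relative order, and anything not listed sinks to the end (index 9999). A minimal sketch of the same rule, using plain string keys as stand-ins for the real vol.Required markers:

# Sketch of the option-ordering rule; string keys stand in for vol.Required markers.
def sort_options(result: dict, global_order: list[str]) -> dict:
    # Keys listed in global_order keep that relative order; unknown keys sink to the end.
    return {
        k: v
        for k, v in sorted(
            result.items(),
            key=lambda item: global_order.index(item[0]) if item[0] in global_order else 9999,
        )
    }

# Usage: backend-specific llama.cpp options land after the general and sampling options.
print(sort_options(
    {"n_threads": 8, "max_new_tokens": 512, "prompt": "..."},
    ["prompt", "max_new_tokens", "n_threads"],
))
# {'prompt': '...', 'max_new_tokens': 512, 'n_threads': 8}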