config sub-entry flow is mostly working now

Alex O'Connell
2025-09-15 22:10:49 -04:00
parent bcfa1dca0a
commit 33362bf9b2
7 changed files with 91 additions and 57 deletions

View File

@@ -6,6 +6,7 @@
- [ ] new model based on gemma3 270m
- [ ] support AI task API
- [ ] move llamacpp to a separate process because of all the crashing
- [ ] optional sampling parameters in options panel (don't pass to backend if not set)
- [x] support new LLM APIs
- rewrite how services are called
- handle no API selected

View File

@@ -169,8 +169,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
api_base_path = entity_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
endpoint = f"/{api_base_path}/chat/completions"
endpoint = "/chat/completions"
return endpoint, request_params
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None) -> Tuple[Optional[str], Optional[List]]:
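Note: with the base path dropped from the endpoint helper, the assumption is that CONF_GENERIC_OPENAI_PATH gets applied once where the full request URL is assembled. A minimal sketch of that idea (helper name and host handling are illustrative, not from this commit):

def build_request_url(host: str, api_base_path: str, endpoint: str) -> str:
    """Join the host, an optional API base path, and the endpoint path."""
    base = host.rstrip("/")
    prefix = f"/{api_base_path.strip('/')}" if api_base_path else ""
    return f"{base}{prefix}{endpoint}"

# e.g. build_request_url("http://localhost:5000", "v1", "/chat/completions")
# -> "http://localhost:5000/v1/chat/completions"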

View File

@@ -13,7 +13,7 @@ from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_D
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import callback
from homeassistant.core import callback, HomeAssistant
from homeassistant.exceptions import ConfigEntryError, HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.event import async_track_state_change, async_call_later
@@ -81,7 +81,7 @@ def snapshot_settings(options: dict[str, Any]) -> dict[str, Any]:
class LlamaCppClient(LocalLLMClient):
llama_cpp_module: Any
llama_cpp_module: Any | None
models: dict[str, LlamaType]
grammars: dict[str, Any]
@@ -96,6 +96,20 @@ class LlamaCppClient(LocalLLMClient):
_attr_supports_streaming = True
def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
super().__init__(hass, client_options)
self.llama_cpp_module = None
self.models = {}
self.grammars = {}
self.loaded_model_settings = {}
self.remove_prompt_caching_listener = None
self.last_cache_prime = 0.0
self.last_updated_entities = {}
self.cache_refresh_after_cooldown = False
self.model_lock = threading.Lock()
async def async_get_available_models(self) -> List[str]:
return [] # TODO: copy from config_flow.py
@@ -103,6 +117,10 @@ class LlamaCppClient(LocalLLMClient):
model_name = entity_options.get(CONF_CHAT_MODEL, "")
model_path = entity_options.get(CONF_DOWNLOADED_MODEL_FILE, "")
if model_name in self.models:
_LOGGER.info("Model %s is already loaded", model_name)
return
_LOGGER.info("Using model file '%s'", model_path)
if not model_path or not os.path.isfile(model_path):
@@ -144,13 +162,6 @@ class LlamaCppClient(LocalLLMClient):
# cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
# ))
self.remove_prompt_caching_listener = None
self.last_cache_prime = 0.0
self.last_updated_entities = {}
self.cache_refresh_after_cooldown = False
self.model_lock = threading.Lock()
if model_settings[CONF_PROMPT_CACHING_ENABLED]:
@callback
async def enable_caching_after_startup(_now) -> None:
@@ -222,10 +233,10 @@ class LlamaCppClient(LocalLLMClient):
if entity_options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
current_grammar = entity_options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
if not self.grammar or self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] != current_grammar:
if model_name not in self.grammars or self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] != current_grammar:
self._load_grammar(model_name, current_grammar)
else:
self.grammar = None
self.grammars[model_name] = None
if entity_options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
self._set_prompt_caching(entity_options, enabled=True)
@@ -342,17 +353,17 @@ class LlamaCppClient(LocalLLMClient):
if llm_api:
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
model_name = entity_options.get(CONF_CHAT_MODEL, "")
temperature = entity_options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_k = int(entity_options.get(CONF_TOP_K, DEFAULT_TOP_K))
top_p = entity_options.get(CONF_TOP_P, DEFAULT_TOP_P)
min_p = entity_options.get(CONF_MIN_P, DEFAULT_MIN_P)
typical_p = entity_options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
grammar = self.grammar if entity_options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None
grammar = self.grammars.get(model_name) if entity_options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None
_LOGGER.debug("Priming model cache via chat completion API...")
try:
model_name = entity_options.get(CONF_CHAT_MODEL, "")
# avoid strict typing issues from the llama-cpp-python bindings
self.models[model_name].create_chat_completion(
messages,
@@ -406,6 +417,7 @@ class LlamaCppClient(LocalLLMClient):
top_p = entity_options.get(CONF_TOP_P, DEFAULT_TOP_P)
min_p = entity_options.get(CONF_MIN_P, DEFAULT_MIN_P)
typical_p = entity_options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
grammar = self.grammars.get(model_name) if entity_options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None
_LOGGER.debug(f"Options: {entity_options}")
@@ -442,7 +454,7 @@ class LlamaCppClient(LocalLLMClient):
min_p=min_p,
typical_p=typical_p,
max_tokens=max_tokens,
grammar=self.grammar,
grammar=grammar,
stream=True,
)
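Note: the single self.grammar attribute is replaced by a self.grammars dict keyed by model name, so each loaded model keeps its own parsed GBNF grammar. A minimal standalone sketch of that pattern, assuming llama-cpp-python's LlamaGrammar.from_string API (the component's real _load_grammar may resolve the grammar file path differently):

import os
from typing import Any, Dict

from llama_cpp import LlamaGrammar  # assumes llama-cpp-python is installed

def load_grammar_for_model(grammars: Dict[str, Any], model_name: str, grammar_path: str) -> None:
    """Parse a GBNF grammar file and cache it under the model's name."""
    if not os.path.isfile(grammar_path):
        grammars[model_name] = None
        return
    with open(grammar_path, encoding="utf-8") as grammar_file:
        grammars[model_name] = LlamaGrammar.from_string(grammar_file.read())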

View File

@@ -138,8 +138,6 @@ from .const import (
DEFAULT_LLAMACPP_BATCH_SIZE,
DEFAULT_LLAMACPP_THREAD_COUNT,
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
BACKEND_TYPE_LLAMA_HF_SETUP,
BACKEND_TYPE_LLAMA_EXISTING_SETUP,
BACKEND_TYPE_LLAMA_CPP,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
@@ -165,7 +163,7 @@ from .backends.ollama import OllamaAPIClient
_LOGGER = logging.getLogger(__name__)
def is_local_backend(backend):
return backend in [BACKEND_TYPE_LLAMA_EXISTING_SETUP, BACKEND_TYPE_LLAMA_HF_SETUP, BACKEND_TYPE_LLAMA_CPP]
return backend == BACKEND_TYPE_LLAMA_CPP
def pick_backend_schema(backend_type=None, selected_language=None):
return vol.Schema(
@@ -442,23 +440,17 @@ class OptionsFlow(BaseOptionsFlow):
)
def STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file=None):
def STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(model_file=None, chat_model=None, downloaded_model_quantization=None, available_quantizations=None):
return vol.Schema(
{
vol.Required(CONF_DOWNLOADED_MODEL_FILE, default=model_file if model_file else ""): str,
}
)
def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_quantization=None, available_quantizations=None):
return vol.Schema(
{
vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): SelectSelector(SelectSelectorConfig(
vol.Optional(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): SelectSelector(SelectSelectorConfig(
options=RECOMMENDED_CHAT_MODELS,
custom_value=True,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
)),
vol.Required(CONF_DOWNLOADED_MODEL_QUANTIZATION, default=downloaded_model_quantization if downloaded_model_quantization else DEFAULT_DOWNLOADED_MODEL_QUANTIZATION): vol.In(available_quantizations if available_quantizations else CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS),
vol.Optional(CONF_DOWNLOADED_MODEL_QUANTIZATION, default=downloaded_model_quantization if downloaded_model_quantization else DEFAULT_DOWNLOADED_MODEL_QUANTIZATION): vol.In(available_quantizations if available_quantizations else CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS),
vol.Optional(CONF_DOWNLOADED_MODEL_FILE, default=model_file if model_file else ""): str,
}
)
@@ -852,7 +844,7 @@ def local_llama_config_option_schema(
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P
CONF_TYPICAL_P,
CONF_TOP_K,
# tool calling/reasoning
CONF_THINKING_PREFIX,
@@ -927,10 +919,8 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
self.model_config = {}
backend_type = entry.options.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
if backend_type == BACKEND_TYPE_LLAMA_HF_SETUP:
schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA()
elif backend_type == BACKEND_TYPE_LLAMA_EXISTING_SETUP:
schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA()
if backend_type == BACKEND_TYPE_LLAMA_CPP:
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA()
else:
schema = STEP_REMOTE_MODEL_SELECTION_DATA_SCHEMA(await entry.runtime_data.async_get_available_models())
@@ -940,7 +930,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
if len(available_quants) == 0:
errors["base"] = "no_supported_ggufs"
schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
chat_model=self.model_config[CONF_CHAT_MODEL],
downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION],
)
@@ -949,7 +939,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
description_placeholders["missing"] = self.download_error.missing_quant
description_placeholders["available"] = ", ".join(self.download_error.available_quants)
schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
chat_model=self.model_config[CONF_CHAT_MODEL],
downloaded_model_quantization=self.download_error.available_quants[0],
available_quantizations=available_quants,
@@ -957,7 +947,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
else:
errors["base"] = "download_failed"
description_placeholders["exception"] = str(self.download_error)
schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
chat_model=self.model_config[CONF_CHAT_MODEL],
downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION],
)
@@ -966,17 +956,24 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
self.model_config.update(user_input)
if backend_type == BACKEND_TYPE_LLAMA_HF_SETUP:
return await self.async_step_download(entry)
elif backend_type == BACKEND_TYPE_LLAMA_EXISTING_SETUP:
model_file = self.model_config[CONF_DOWNLOADED_MODEL_FILE]
if backend_type == BACKEND_TYPE_LLAMA_CPP:
model_file = self.model_config.get(CONF_DOWNLOADED_MODEL_FILE, "")
if not model_file:
model_name = self.model_config.get(CONF_CHAT_MODEL)
if model_name:
return await self.async_step_download(entry)
else:
errors["base"] = "no_model_name_or_file"
if os.path.exists(model_file):
self.model_config[CONF_CHAT_MODEL] = os.path.basename(model_file)
self.internal_step = "model_parameters"
return await self.async_step_model_parameters(None, entry)
else:
errors["base"] = "missing_model_file"
schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file)
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(model_file)
else:
self.internal_step = "model_parameters"
return await self.async_step_model_parameters(None, entry)
return self.async_show_form(
@@ -1017,12 +1014,15 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
_LOGGER.info("Failed to download model: %s", repr(download_exception))
self.download_error = download_exception
self.internal_step = "select_local_model"
self.download_task = None
return self.async_show_progress_done(next_step_id="failed")
else:
self.model_config[CONF_DOWNLOADED_MODEL_FILE] = self.download_task.result()
self.internal_step = "model_parameters"
self.download_task = None
return self.async_show_progress_done(next_step_id="finish")
self.download_task = None
return self.async_show_progress_done(next_step_id="model_parameters")
async def async_step_model_parameters(
self, user_input: dict[str, Any] | None,
@@ -1082,9 +1082,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
try:
# validate input
schema(user_input)
self.model_config.update(user_input)
return await self.async_step_finish()
except Exception:
_LOGGER.exception("An unknown error has occurred!")
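Note: the llama.cpp branch of the model-selection step reduces to this decision, restated as a standalone sketch (hypothetical helper; the real flow records values in self.model_config and advances flow steps rather than returning strings):

import os

def resolve_local_model_step(model_file: str, model_name: str) -> str:
    """Decide the next step for the llama.cpp backend's model selection."""
    if not model_file:
        # No local file given: download the named model, if one was provided.
        return "download" if model_name else "error:no_model_name_or_file"
    if os.path.exists(model_file):
        # An existing file wins; CONF_CHAT_MODEL is derived from its basename.
        return "model_parameters"
    return "error:missing_model_file"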

View File

@@ -104,8 +104,8 @@ DEFAULT_TEMPERATURE = 0.1
CONF_REQUEST_TIMEOUT = "request_timeout"
DEFAULT_REQUEST_TIMEOUT = 90
CONF_BACKEND_TYPE = "model_backend"
BACKEND_TYPE_LLAMA_HF_SETUP = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING_SETUP = "llama_cpp_existing"
BACKEND_TYPE_LLAMA_HF_OLD = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING_OLD = "llama_cpp_existing"
BACKEND_TYPE_LLAMA_CPP = "llama_cpp_python"
BACKEND_TYPE_TEXT_GEN_WEBUI = "text-generation-webui_api"
BACKEND_TYPE_GENERIC_OPENAI = "generic_openai"
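Note: the renamed *_OLD constants suggest existing config entries may still carry the retired llama.cpp backend values. A hypothetical compatibility shim, not part of this commit, that folds them into the single remaining backend:

LEGACY_LLAMA_BACKENDS = {BACKEND_TYPE_LLAMA_HF_OLD, BACKEND_TYPE_LLAMA_EXISTING_OLD}

def normalize_backend_type(backend_type: str) -> str:
    """Map retired llama.cpp backend identifiers onto BACKEND_TYPE_LLAMA_CPP."""
    if backend_type in LEGACY_LLAMA_BACKENDS:
        return BACKEND_TYPE_LLAMA_CPP
    return backend_type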

View File

@@ -1,12 +1,14 @@
"""Defines the various LLM Backend Agents"""
from __future__ import annotations
from typing import Literal
import logging
from homeassistant.components.conversation import ConversationInput, ConversationResult, ConversationEntity
from homeassistant.components.conversation.models import AbstractConversationAgent
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.helpers import chat_session
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
@@ -21,20 +23,31 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry, asy
"""Set up Local LLM Conversation from a config entry."""
for subentry in entry.subentries.values():
if subentry.subentry_type != "conversation":
if subentry.subentry_type != conversation.DOMAIN:
continue
# create one agent entity per conversation subentry
agent_entity = LocalLLMAgent(hass, entry, entry.runtime_data, subentry)
agent_entity = LocalLLMAgent(hass, entry, subentry, entry.runtime_data)
# make sure model is loaded
await entry.runtime_data._async_load_model(dict(subentry.data))
# register the agent entity
async_add_entities([agent_entity])
return True
class LocalLLMAgent(LocalLLMEntity, ConversationEntity, AbstractConversationAgent):
class LocalLLMAgent(ConversationEntity, AbstractConversationAgent, LocalLLMEntity):
"""Base Local LLM conversation agent."""
def __init__(self, hass: HomeAssistant, entry: ConfigEntry, subentry: ConfigSubentry, client: LocalLLMClient) -> None:
super().__init__(hass, entry, subentry, client)
if subentry.data.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
async def async_added_to_hass(self) -> None:
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()
@@ -45,6 +58,11 @@ class LocalLLMAgent(LocalLLMEntity, ConversationEntity, AbstractConversationAgen
conversation.async_unset_agent(self.hass, self.entry)
await super().async_will_remove_from_hass()
@property
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
return MATCH_ALL
async def async_process(
self, user_input: ConversationInput
) -> ConversationResult:

View File

@@ -130,6 +130,16 @@ class LocalLLMClient:
await self.hass.async_add_executor_job(
self._load_model, entity_options
)
def _unload_model(self, entity_options: dict[str, Any]) -> None:
"""Unload the model to free up space on the backend. Implemented by sub-classes"""
pass
async def _async_unload_model(self, entity_options: dict[str, Any]) -> None:
"""Default implementation is to call _unload_model() which probably does blocking stuff"""
await self.hass.async_add_executor_job(
self._unload_model, entity_options
)
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput, entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
"""Async generator for streaming responses. Subclasses should implement."""
@@ -632,10 +642,6 @@ class LocalLLMEntity(entity.Entity):
self.subentry_id = subentry.subentry_id
self.client = client
if subentry.data.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
def handle_reload(self):
self.client._update_options(self.runtime_options)
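Note: _unload_model() is the new hook backends can override to free resources; _async_unload_model() just runs it in the executor. An illustrative override for a backend like the llama.cpp client (the real integration may release resources differently; this assumes the models/grammars dicts and model_lock shown earlier):

def _unload_model(self, entity_options: dict[str, Any]) -> None:
    """Drop the cached model and grammar so their memory can be reclaimed."""
    model_name = entity_options.get(CONF_CHAT_MODEL, "")
    with self.model_lock:
        self.models.pop(model_name, None)
        self.grammars.pop(model_name, None)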