fix migration + config sub-entry setup

Alex O'Connell
2025-09-17 18:24:08 -04:00
parent bfc6a5a753
commit b02118705d
10 changed files with 107 additions and 57 deletions

View File

@@ -21,6 +21,7 @@ from .const import (
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
CONF_BACKEND_TYPE,
CONF_INSTALLED_LLAMACPP_VERSION,
CONF_SELECTED_LANGUAGE,
CONF_OPENAI_API_KEY,
CONF_GENERIC_OPENAI_PATH,
@@ -41,12 +42,16 @@ from .const import (
BACKEND_TYPE_GENERIC_OPENAI_RESPONSES,
BACKEND_TYPE_LLAMA_CPP_SERVER,
BACKEND_TYPE_OLLAMA,
BACKEND_TYPE_LLAMA_EXISTING_OLD,
BACKEND_TYPE_LLAMA_HF_OLD,
EMBEDDED_LLAMA_CPP_PYTHON_VERSION
)
from .entity import LocalLLMClient, LocalLLMConfigEntry
from .backends.llamacpp import LlamaCppClient
from .backends.generic_openai import GenericOpenAIAPIClient, GenericOpenAIResponsesAPIClient
from .backends.tailored_openai import TextGenerationWebuiClient, LlamaCppServerClient
from .backends.ollama import OllamaAPIClient
from .utils import get_llama_cpp_python_version
_LOGGER = logging.getLogger(__name__)
@@ -63,14 +68,6 @@ BACKEND_TO_CLS: dict[str, type[LocalLLMClient]] = {
BACKEND_TYPE_OLLAMA: OllamaAPIClient,
}
async def update_listener(hass: HomeAssistant, entry: LocalLLMConfigEntry):
"""Handle options update."""
hass.data[DOMAIN][entry.entry_id] = entry
# call update handler
client: LocalLLMClient = entry.runtime_data
await hass.async_add_executor_job(client._update_options, dict(entry.options))
async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
# make sure the API is registered
@@ -87,14 +84,17 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) ->
backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
entry.runtime_data = await hass.async_add_executor_job(create_client, backend_type)
# handle updates to the options
entry.async_on_unload(entry.add_update_listener(update_listener))
# forward setup to platform to register the entity
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
entry.async_on_unload(entry.add_update_listener(_async_update_listener))
return True
async def _async_update_listener(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> None:
await hass.config_entries.async_reload(entry.entry_id)
async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
"""Unload Ollama."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
@@ -114,8 +114,8 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigE
# Migrate the config_entry to be an entry + sub-entry
if config_entry.version == 2:
ENTRY_DATA_KEYS = [CONF_BACKEND_TYPE, CONF_SELECTED_LANGUAGE]
ENTRY_OPTIONS_KEYS = [CONF_HOST, CONF_PORT, CONF_SSL, CONF_OPENAI_API_KEY, CONF_GENERIC_OPENAI_PATH]
ENTRY_DATA_KEYS = [CONF_BACKEND_TYPE]
ENTRY_OPTIONS_KEYS = [CONF_SELECTED_LANGUAGE, CONF_HOST, CONF_PORT, CONF_SSL, CONF_OPENAI_API_KEY, CONF_GENERIC_OPENAI_PATH]
SUBENTRY_KEYS = [
CONF_CHAT_MODEL, CONF_DOWNLOADED_MODEL_QUANTIZATION, CONF_DOWNLOADED_MODEL_FILE, CONF_REQUEST_TIMEOUT, CONF_MAX_TOOL_CALL_ITERATIONS,
CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS, CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
@@ -137,29 +137,29 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigE
entry_options = { k: v for k, v in source_data.items() if k in ENTRY_OPTIONS_KEYS }
subentry_data = { k: v for k, v in source_data.items() if k in SUBENTRY_KEYS }
backend = config_entry.data[CONF_BACKEND_TYPE]
if backend == BACKEND_TYPE_LLAMA_EXISTING_OLD or backend == BACKEND_TYPE_LLAMA_HF_OLD:
backend = BACKEND_TYPE_LLAMA_CPP
entry_data[CONF_BACKEND_TYPE] = BACKEND_TYPE_LLAMA_CPP
entry_options[CONF_INSTALLED_LLAMACPP_VERSION] = await hass.async_add_executor_job(get_llama_cpp_python_version) or EMBEDDED_LLAMA_CPP_PYTHON_VERSION
else:
# ensure all remote backends have a path set
entry_options[CONF_GENERIC_OPENAI_PATH] = entry_options.get(CONF_GENERIC_OPENAI_PATH, "")
entry_title = BACKEND_TO_CLS[backend].get_name(entry_options)
subentry = ConfigSubentry(
data=MappingProxyType(subentry_data),
subentry_type="conversation",
title=config_entry.title, # FIXME: should be the "new" name format
title=config_entry.title.split("'")[-2],
unique_id=None,
)
# don't attempt to create duplicate llama.cpp backends
if entry_data[CONF_BACKEND_TYPE] == BACKEND_TYPE_LLAMA_CPP:
all_entries = hass.config_entries.async_entries(DOMAIN)
for entry in all_entries:
if entry.version < 3 or entry.data[CONF_BACKEND_TYPE] != BACKEND_TYPE_LLAMA_CPP:
continue
await hass.config_entries.async_remove(config_entry.entry_id)
config_entry = entry
break
# create sub-entry
hass.config_entries.async_add_subentry(config_entry, subentry)
# update the parent entry
hass.config_entries.async_update_entry(config_entry, data=entry_data, options=entry_options, version=3)
hass.config_entries.async_update_entry(config_entry, title=entry_title, data=entry_data, options=entry_options, version=3)
_LOGGER.debug("Migration to subentries complete")

View File

@@ -59,6 +59,14 @@ class GenericOpenAIAPIClient(LocalLLMClient):
)
self.api_key = client_options.get(CONF_OPENAI_API_KEY, "")
@staticmethod
def get_name(client_options: dict[str, Any]):
host = client_options[CONF_HOST]
port = client_options[CONF_PORT]
ssl = client_options[CONF_SSL]
path = "/" + client_options[CONF_GENERIC_OPENAI_PATH]
return f"Generic OpenAI at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
@staticmethod
async def async_validate_connection(hass: HomeAssistant, user_input: Dict[str, Any]) -> str | None:
@@ -233,6 +241,14 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
self.api_key = client_options.get(CONF_OPENAI_API_KEY, "")
@staticmethod
def get_name(client_options: dict[str, Any]):
host = client_options[CONF_HOST]
port = client_options[CONF_PORT]
ssl = client_options[CONF_SSL]
path = "/" + client_options[CONF_GENERIC_OPENAI_PATH]
return f"Generic OpenAI at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
def _responses_params(self, conversation: List[conversation.Content], entity_options: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
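For context, all of the new get_name() helpers build their titles through format_url() from utils.py. A rough sketch of what that helper likely does, inferred from its call sites above (the actual implementation may differ):

# Assumed helper shape; only the keyword arguments are taken from the diff.
def format_url(*, hostname: str, port: str, ssl: bool, path: str) -> str:
    scheme = "https" if ssl else "http"
    return f"{scheme}://{hostname}:{port}{path}"

# e.g. format_url(hostname="192.168.1.10", port="8000", ssl=False, path="/v1")
# -> "http://192.168.1.10:8000/v1", giving a title like
# "Generic OpenAI at 'http://192.168.1.10:8000/v1'"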

View File

@@ -20,6 +20,7 @@ from homeassistant.helpers.event import async_track_state_change, async_call_lat
from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation, get_oai_formatted_messages, get_oai_formatted_tools, parse_raw_tool_call
from custom_components.llama_conversation.const import (
CONF_INSTALLED_LLAMACPP_VERSION,
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
@@ -112,6 +113,10 @@ class LlamaCppClient(LocalLLMClient):
self.cache_refresh_after_cooldown = False
self.model_lock = threading.Lock()
@staticmethod
def get_name(client_options: dict[str, Any]):
return f"LLama.cpp (llama-cpp-python v{client_options[CONF_INSTALLED_LLAMACPP_VERSION]})"
async def async_get_available_models(self) -> List[str]:
return [] # TODO: find available "huggingface_hub" models that have been downloaded
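With the installed llama-cpp-python version now persisted on the parent entry options, the new static get_name() can be called without a live client. A small usage sketch (the version value is illustrative; the key string matches CONF_INSTALLED_LLAMACPP_VERSION in const.py below):

client_options = {"installed_llama_cpp_version": "0.3.4"}  # value is illustrative
print(LlamaCppClient.get_name(client_options))
# -> LLama.cpp (llama-cpp-python v0.3.4)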

View File

@@ -62,6 +62,14 @@ class OllamaAPIClient(LocalLLMClient):
self.api_key = client_options.get(CONF_OPENAI_API_KEY, "")
@staticmethod
def get_name(client_options: dict[str, Any]):
host = client_options[CONF_HOST]
port = client_options[CONF_PORT]
ssl = client_options[CONF_SSL]
path = "/" + client_options[CONF_GENERIC_OPENAI_PATH]
return f"Ollama at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
@staticmethod
async def async_validate_connection(hass: HomeAssistant, user_input: Dict[str, Any]) -> str | None:
headers = {}

View File

@@ -5,6 +5,7 @@ import logging
import os
from typing import Optional, Tuple, Dict, List, Any
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession
@@ -20,6 +21,7 @@ from custom_components.llama_conversation.const import (
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_CONTEXT_LENGTH,
CONF_GENERIC_OPENAI_PATH,
DEFAULT_TOP_K,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
@@ -32,6 +34,7 @@ from custom_components.llama_conversation.const import (
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
)
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIClient
from custom_components.llama_conversation.utils import format_url
_LOGGER = logging.getLogger(__name__)
@@ -43,6 +46,14 @@ class TextGenerationWebuiClient(GenericOpenAIAPIClient):
self.admin_key = client_options.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)
@staticmethod
def get_name(client_options: dict[str, Any]):
host = client_options[CONF_HOST]
port = client_options[CONF_PORT]
ssl = client_options[CONF_SSL]
path = "/" + client_options[CONF_GENERIC_OPENAI_PATH]
return f"Text-Gen WebUI at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
async def _async_load_model(self, entity_options: dict[str, Any]) -> None:
model_name = entity_options.get(CONF_CHAT_MODEL)
try:
@@ -107,6 +118,14 @@ class LlamaCppServerClient(GenericOpenAIAPIClient):
grammar_file_name = client_options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), grammar_file_name)) as f:
self.grammar = "".join(f.readlines())
@staticmethod
def get_name(client_options: dict[str, Any]):
host = client_options[CONF_HOST]
port = client_options[CONF_PORT]
ssl = client_options[CONF_SSL]
path = "/" + client_options[CONF_GENERIC_OPENAI_PATH]
return f"LLama.cpp Server at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
def _chat_completion_params(self, entity_options: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
top_k = int(entity_options.get(CONF_TOP_K, DEFAULT_TOP_K))

View File

@@ -52,6 +52,7 @@ from .const import (
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_INSTALLED_LLAMACPP_VERSION,
CONF_SELECTED_LANGUAGE,
CONF_SELECTED_LANGUAGE_OPTIONS,
CONF_DOWNLOADED_MODEL_FILE,
@@ -221,7 +222,6 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
install_wheel_task = None
install_wheel_error = None
installed_version = None
client_config: dict[str, Any]
internal_step: str = "init"
@@ -257,9 +257,10 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
backend = user_input[CONF_BACKEND_TYPE]
self.client_config.update(user_input)
if backend == BACKEND_TYPE_LLAMA_CPP:
self.installed_version = await self.hass.async_add_executor_job(get_llama_cpp_python_version)
_LOGGER.debug(f"installed version: {self.installed_version}")
if self.installed_version == EMBEDDED_LLAMA_CPP_PYTHON_VERSION:
installed_version = await self.hass.async_add_executor_job(get_llama_cpp_python_version)
_LOGGER.debug(f"installed version: {installed_version}")
if installed_version == EMBEDDED_LLAMA_CPP_PYTHON_VERSION:
self.client_config[CONF_INSTALLED_LLAMACPP_VERSION] = installed_version
return await self.async_step_finish()
else:
self.internal_step = "install_local_wheels"
@@ -315,6 +316,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
else:
_LOGGER.debug(f"Finished install: {wheel_install_result}")
next_step = "finish"
self.client_config[CONF_INSTALLED_LLAMACPP_VERSION] = await self.hass.async_add_executor_job(get_llama_cpp_python_version)
self.install_wheel_task = None
self.internal_step = next_step
@@ -360,23 +362,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult:
backend = self.client_config[CONF_BACKEND_TYPE]
title = "Generic AI Provider"
if backend == BACKEND_TYPE_LLAMA_CPP:
title = f"LLama.cpp (llama-cpp-python v{self.installed_version})"
else:
host = self.client_config[CONF_HOST]
port = self.client_config[CONF_PORT]
ssl = self.client_config[CONF_SSL]
path = "/" + self.client_config[CONF_GENERIC_OPENAI_PATH]
if backend == BACKEND_TYPE_GENERIC_OPENAI or backend == BACKEND_TYPE_GENERIC_OPENAI_RESPONSES:
title = f"Generic OpenAI at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
elif backend == BACKEND_TYPE_TEXT_GEN_WEBUI:
title = f"Text-Gen WebUI at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
elif backend == BACKEND_TYPE_OLLAMA:
title = f"Ollama at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
elif backend == BACKEND_TYPE_LLAMA_CPP_SERVER:
title = f"LLama.cpp Server at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
title = BACKEND_TO_CLS[backend].get_name(self.client_config)
_LOGGER.debug(f"creating provider with config: {self.client_config}")
# block duplicate providers
@@ -394,7 +380,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
@classmethod
def async_supports_options_flow(cls, config_entry: ConfigEntry) -> bool:
return config_entry.options.get(CONF_BACKEND_TYPE) != BACKEND_TYPE_LLAMA_CPP
return config_entry.data[CONF_BACKEND_TYPE] != BACKEND_TYPE_LLAMA_CPP
@staticmethod
def async_get_options_flow(
@@ -1036,7 +1022,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
errors = {}
description_placeholders = {}
entry = self._get_entry()
backend_type = entry.options[CONF_BACKEND_TYPE]
backend_type = entry.data[CONF_BACKEND_TYPE]
if not self.model_config:
# determine selected language from model config or parent options
@@ -1061,7 +1047,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
schema = vol.Schema(
local_llama_config_option_schema(
self.hass,
entry.options[CONF_SELECTED_LANGUAGE],
entry.options.get(CONF_SELECTED_LANGUAGE, "en"),
self.model_config,
backend_type,
self._subentry_type,

View File

@@ -113,6 +113,7 @@ BACKEND_TYPE_GENERIC_OPENAI_RESPONSES = "generic_openai_responses"
BACKEND_TYPE_LLAMA_CPP_SERVER = "llama_cpp_server"
BACKEND_TYPE_OLLAMA = "ollama"
DEFAULT_BACKEND_TYPE = BACKEND_TYPE_LLAMA_CPP
CONF_INSTALLED_LLAMACPP_VERSION = "installed_llama_cpp_version"
CONF_SELECTED_LANGUAGE = "selected_language"
CONF_SELECTED_LANGUAGE_OPTIONS = [ "en", "de", "fr", "es", "pl"]
CONF_DOWNLOADED_MODEL_QUANTIZATION = "downloaded_model_quantization"

View File

@@ -33,7 +33,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry, asy
await entry.runtime_data._async_load_model(dict(subentry.data))
# register the agent entity
async_add_entities([agent_entity])
async_add_entities([agent_entity], config_subentry_id=subentry.subentry_id,)
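Passing config_subentry_id ties the created entity (and its device) to the conversation sub-entry instead of the parent provider entry. A minimal sketch of the surrounding per-subentry setup loop, which is not shown in this hunk and is assumed here (the entity class name and constructor are hypothetical):

for subentry in entry.subentries.values():
    if subentry.subentry_type != "conversation":
        continue
    # hypothetical construction; the real agent entity class may differ
    agent_entity = LocalLLMAgent(hass, entry, subentry, entry.runtime_data)
    await entry.runtime_data._async_load_model(dict(subentry.data))
    async_add_entities([agent_entity], config_subentry_id=subentry.subentry_id)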
return True

View File

@@ -18,11 +18,11 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError, HomeAssistantError
from homeassistant.helpers import intent, template, entity_registry as er, llm, \
area_registry as ar, device_registry as dr, entity
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util import color
from .utils import closest_color, parse_raw_tool_call, flatten_vol_schema
from .const import (
CONF_CHAT_MODEL,
CONF_SELECTED_LANGUAGE,
CONF_PROMPT,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
@@ -84,6 +84,10 @@ class LocalLLMClient:
if icl_examples_filename:
self._load_icl_examples(icl_examples_filename)
@staticmethod
def get_name(client_options: dict[str, Any]):
raise NotImplementedError()
def _load_icl_examples(self, filename: str):
"""Load info used for generating in context learning examples"""
try:
@@ -635,15 +639,26 @@ class LocalLLMEntity(entity.Entity):
"""Initialize the agent."""
self._attr_name = subentry.title
self._attr_unique_id = subentry.subentry_id
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, subentry.subentry_id)},
name=subentry.title,
model=subentry.data.get(CONF_CHAT_MODEL),
entry_type=dr.DeviceEntryType.SERVICE,
)
self.hass = hass
self.entry_id = entry.entry_id
self.subentry_id = subentry.subentry_id
self.client = client
# create update handler
self.async_on_remove(entry.add_update_listener(self._async_update_options))
def handle_reload(self):
self.client._update_options(self.runtime_options)
async def _async_update_options(self, hass: HomeAssistant, config_entry: LocalLLMConfigEntry):
for subentry in config_entry.subentries.values():
# handle subentry updates, but only invoke for this entity
if subentry.subentry_id == self.subentry_id:
await hass.async_add_executor_job(self.client._update_options, self.runtime_options)
@property
def entry(self) -> ConfigEntry:
@@ -651,7 +666,7 @@ class LocalLLMEntity(entity.Entity):
return self.hass.data[DOMAIN][self.entry_id]
except KeyError as ex:
raise Exception("Attempted to use self.entry during startup.") from ex
@property
def subentry(self) -> ConfigSubentry:
try:
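The listener above reads self.runtime_options, which is not part of this diff. A plausible shape, assuming it layers the per-conversation sub-entry data over the parent entry's provider options:

@property
def runtime_options(self) -> dict:
    # assumption: sub-entry settings take precedence over provider-level options
    return {**self.entry.options, **self.subentry.data}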

View File

@@ -226,12 +226,12 @@ def install_llama_cpp_python(config_dir: str):
_LOGGER.info("Installing llama-cpp-python from local wheel")
_LOGGER.debug(f"Wheel location: {latest_wheel}")
return install_package(os.path.join(folder, latest_wheel), pip_kwargs(config_dir))
return install_package(os.path.join(folder, latest_wheel), **pip_kwargs(config_dir))
# scikit-build-core v0.9.7+ doesn't recognize these builds as musllinux, and just tags them as generic linux
# github_release_url = f"https://github.com/acon96/home-llm/releases/download/v{INTEGRATION_VERSION}/llama_cpp_python-{EMBEDDED_LLAMA_CPP_PYTHON_VERSION}+homellm{instruction_extensions_suffix}-{runtime_version}-{runtime_version}-musllinux_1_2_{platform_suffix}.whl"
github_release_url = f"https://github.com/acon96/home-llm/releases/download/v{INTEGRATION_VERSION}/llama_cpp_python-{EMBEDDED_LLAMA_CPP_PYTHON_VERSION}+homellm{instruction_extensions_suffix}-{runtime_version}-{runtime_version}-linux_{platform_suffix}.whl"
if install_package(github_release_url, pip_kwargs(config_dir)):
if install_package(github_release_url, **pip_kwargs(config_dir)):
_LOGGER.info("llama-cpp-python successfully installed from GitHub release")
return True
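The added ** matters because install_package() from homeassistant.util.package accepts its pip settings as keyword arguments; passing the pip_kwargs() dict positionally handed the whole dict to a single parameter instead of spreading it. A short illustration (the exact contents returned by pip_kwargs are an assumption):

from homeassistant.util.package import install_package

kwargs = pip_kwargs(config_dir)                   # repo helper; e.g. constraints/target values
install_package(github_release_url, **kwargs)     # correct: expanded into keyword arguments
# install_package(github_release_url, kwargs)     # old bug: dict bound to one positional slot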