fix config entry setup

Alex O'Connell
2025-04-13 18:00:26 -04:00
parent b91368d17f
commit 811c1ea23b
3 changed files with 48 additions and 40 deletions

View File

@@ -8,13 +8,11 @@ import homeassistant.components.conversation as ha_conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, llm
from homeassistant.util.json import JsonObjectType
import voluptuous as vol
from .const import (
    ALLOWED_SERVICE_CALL_ARGUMENTS,
    DOMAIN,
@@ -22,7 +20,19 @@ from .const import (
    SERVICE_TOOL_NAME,
    SERVICE_TOOL_ALLOWED_SERVICES,
    SERVICE_TOOL_ALLOWED_DOMAINS,
    CONF_BACKEND_TYPE,
    DEFAULT_BACKEND_TYPE,
    BACKEND_TYPE_LLAMA_HF,
    BACKEND_TYPE_LLAMA_EXISTING,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
    BACKEND_TYPE_GENERIC_OPENAI,
    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
    BACKEND_TYPE_OLLAMA,
)
from .conversation import LlamaCppAgent, GenericOpenAIAPIAgent, TextGenerationWebuiAgent, \
    LlamaCppPythonAPIAgent, OllamaAPIAgent, LocalLLMAgent

type LocalLLMConfigEntry = ConfigEntry[LocalLLMAgent]

_LOGGER = logging.getLogger(__name__)
@@ -31,7 +41,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
    # make sure the API is registered
    if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(hass)]):
@@ -39,18 +49,43 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry

    def create_agent(backend_type):
        agent_cls = None
        if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
            agent_cls = LlamaCppAgent
        elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
            agent_cls = GenericOpenAIAPIAgent
        elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
            agent_cls = TextGenerationWebuiAgent
        elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
            agent_cls = LlamaCppPythonAPIAgent
        elif backend_type == BACKEND_TYPE_OLLAMA:
            agent_cls = OllamaAPIAgent
        return agent_cls(hass, entry)

    # create the agent in an executor job because the constructor calls `open()`
    backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
    entry.runtime_data = await hass.async_add_executor_job(create_agent, backend_type)

    # call load model
    await entry.runtime_data._async_load_model(entry)

    # forward setup to platform to register the entity
    await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

    return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
"""Unload Ollama."""
    if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
        return False

    hass.data[DOMAIN].pop(entry.entry_id)
    return True
async def async_migrate_entry(hass, config_entry: ConfigEntry):
async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigEntry):
"""Migrate old entry."""
_LOGGER.debug("Migrating from version %s", config_entry.version)

View File

@@ -687,6 +687,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
        for key in OPTIONS_OVERRIDES.keys():
            if key in model_name:
                selected_default_options.update(OPTIONS_OVERRIDES[key])
                break

        persona = PERSONA_PROMPTS.get(self.selected_language, PERSONA_PROMPTS.get("en"))
        current_date = CURRENT_DATE_PROMPT.get(self.selected_language, CURRENT_DATE_PROMPT.get("en"))
@@ -765,15 +766,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
        config_entry: config_entries.ConfigEntry,
    ) -> config_entries.OptionsFlow:
        """Create the options flow."""
        return OptionsFlow(config_entry)
        return OptionsFlow()


class OptionsFlow(config_entries.OptionsFlow):
    """Local LLM config flow options handler."""

    def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
        """Initialize options flow."""
        self.config_entry = config_entry

    @property
    def config_entry(self):
        return self.hass.config_entries.async_get_entry(self.handler)

    async def async_step_init(
        self, user_input: dict[str, Any] | None = None
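
Note: the options flow no longer receives the config entry through its constructor; the config_entry property resolves it from self.handler at runtime. A minimal sketch, using a hypothetical "prompt" option key that is not from this commit, of how a step can read the current options through that property:

import voluptuous as vol
from homeassistant import config_entries

class ExampleOptionsFlow(config_entries.OptionsFlow):
    async def async_step_init(self, user_input=None):
        if user_input is not None:
            # persist the submitted values as the entry's new options
            return self.async_create_entry(title="", data=user_input)
        current = self.config_entry.options  # looked up through self.handler
        return self.async_show_form(
            step_id="init",
            data_schema=vol.Schema({
                vol.Optional("prompt", default=current.get("prompt", "")): str,
            }),
        )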

View File

@@ -125,12 +125,6 @@ from .const import (
    SERVICE_TOOL_ALLOWED_DOMAINS,
    CONF_BACKEND_TYPE,
    DEFAULT_BACKEND_TYPE,
    BACKEND_TYPE_LLAMA_HF,
    BACKEND_TYPE_LLAMA_EXISTING,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
    BACKEND_TYPE_GENERIC_OPENAI,
    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
    BACKEND_TYPE_OLLAMA,
)
# make type checking work for llama-cpp-python without importing it directly at runtime
@@ -149,7 +143,7 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
    hass.data[DOMAIN][entry.entry_id] = entry

    # call update handler
    agent: LocalLLMAgent = ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
    agent: LocalLLMAgent = entry.runtime_data
    await hass.async_add_executor_job(agent._update_options)

    return True
@@ -157,33 +151,11 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback) -> bool:
"""Set up Local LLM Conversation from a config entry."""
def create_agent(backend_type):
agent_cls = None
if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
agent_cls = LlamaCppAgent
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
agent_cls = GenericOpenAIAPIAgent
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
agent_cls = TextGenerationWebuiAgent
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
agent_cls = LlamaCppPythonAPIAgent
elif backend_type == BACKEND_TYPE_OLLAMA:
agent_cls = OllamaAPIAgent
return agent_cls(hass, entry)
# create the agent in an executor job because the constructor calls `open()`
backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
agent = await hass.async_add_executor_job(create_agent, backend_type)
# call load model
await agent._async_load_model(entry)
# handle updates to the options
entry.async_on_unload(entry.add_update_listener(update_listener))
async_add_entities([agent])
# register the agent entity
async_add_entities([entry.runtime_data])
return True
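
Note: after this commit the conversation platform no longer builds the agent itself; it only wires the update listener and registers the entity already stored on the entry. Roughly, the surviving body reduces to the sketch below (assembled from the lines kept above and assuming the module's existing imports, not copied verbatim from the repository):

async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback) -> bool:
    """Set up Local LLM Conversation from a config entry."""
    # the agent entity was created in __init__.py and stored on entry.runtime_data
    entry.async_on_unload(entry.add_update_listener(update_listener))

    # register the agent entity
    async_add_entities([entry.runtime_data])
    return True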