mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
fix config entry setup
This commit is contained in:
@@ -8,13 +8,11 @@ import homeassistant.components.conversation as ha_conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ATTR_ENTITY_ID, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv, llm
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
||||
from .const import (
|
||||
ALLOWED_SERVICE_CALL_ARGUMENTS,
|
||||
DOMAIN,
|
||||
@@ -22,7 +20,19 @@ from .const import (
|
||||
SERVICE_TOOL_NAME,
|
||||
SERVICE_TOOL_ALLOWED_SERVICES,
|
||||
SERVICE_TOOL_ALLOWED_DOMAINS,
|
||||
CONF_BACKEND_TYPE,
|
||||
DEFAULT_BACKEND_TYPE,
|
||||
BACKEND_TYPE_LLAMA_HF,
|
||||
BACKEND_TYPE_LLAMA_EXISTING,
|
||||
BACKEND_TYPE_TEXT_GEN_WEBUI,
|
||||
BACKEND_TYPE_GENERIC_OPENAI,
|
||||
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
|
||||
BACKEND_TYPE_OLLAMA,
|
||||
)
|
||||
from .conversation import LlamaCppAgent, GenericOpenAIAPIAgent, TextGenerationWebuiAgent, \
|
||||
LlamaCppPythonAPIAgent, OllamaAPIAgent, LocalLLMAgent
|
||||
|
||||
type LocalLLMConfigEntry = ConfigEntry[LocalLLMAgent]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,7 +41,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
PLATFORMS = (Platform.CONVERSATION,)
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
|
||||
|
||||
# make sure the API is registered
|
||||
if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(hass)]):
|
||||
@@ -39,18 +49,43 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry
|
||||
|
||||
def create_agent(backend_type):
|
||||
agent_cls = None
|
||||
|
||||
if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
|
||||
agent_cls = LlamaCppAgent
|
||||
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
|
||||
agent_cls = GenericOpenAIAPIAgent
|
||||
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
|
||||
agent_cls = TextGenerationWebuiAgent
|
||||
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
|
||||
agent_cls = LlamaCppPythonAPIAgent
|
||||
elif backend_type == BACKEND_TYPE_OLLAMA:
|
||||
agent_cls = OllamaAPIAgent
|
||||
|
||||
return agent_cls(hass, entry)
|
||||
|
||||
# create the agent in an executor job because the constructor calls `open()`
|
||||
backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
|
||||
entry.runtime_data = await hass.async_add_executor_job(create_agent, backend_type)
|
||||
|
||||
# call load model
|
||||
await entry.runtime_data._async_load_model(entry)
|
||||
|
||||
# forward setup to platform to register the entity
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
|
||||
"""Unload Ollama."""
|
||||
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
||||
return False
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
return True
|
||||
|
||||
async def async_migrate_entry(hass, config_entry: ConfigEntry):
|
||||
async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigEntry):
|
||||
"""Migrate old entry."""
|
||||
_LOGGER.debug("Migrating from version %s", config_entry.version)
|
||||
|
||||
|
||||
@@ -687,6 +687,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
|
||||
for key in OPTIONS_OVERRIDES.keys():
|
||||
if key in model_name:
|
||||
selected_default_options.update(OPTIONS_OVERRIDES[key])
|
||||
break
|
||||
|
||||
persona = PERSONA_PROMPTS.get(self.selected_language, PERSONA_PROMPTS.get("en"))
|
||||
current_date = CURRENT_DATE_PROMPT.get(self.selected_language, CURRENT_DATE_PROMPT.get("en"))
|
||||
@@ -765,15 +766,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
|
||||
config_entry: config_entries.ConfigEntry,
|
||||
) -> config_entries.OptionsFlow:
|
||||
"""Create the options flow."""
|
||||
return OptionsFlow(config_entry)
|
||||
return OptionsFlow()
|
||||
|
||||
|
||||
class OptionsFlow(config_entries.OptionsFlow):
|
||||
"""Local LLM config flow options handler."""
|
||||
|
||||
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
|
||||
"""Initialize options flow."""
|
||||
self.config_entry = config_entry
|
||||
@property
|
||||
def config_entry(self):
|
||||
return self.hass.config_entries.async_get_entry(self.handler)
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
|
||||
@@ -125,12 +125,6 @@ from .const import (
|
||||
SERVICE_TOOL_ALLOWED_DOMAINS,
|
||||
CONF_BACKEND_TYPE,
|
||||
DEFAULT_BACKEND_TYPE,
|
||||
BACKEND_TYPE_LLAMA_HF,
|
||||
BACKEND_TYPE_LLAMA_EXISTING,
|
||||
BACKEND_TYPE_TEXT_GEN_WEBUI,
|
||||
BACKEND_TYPE_GENERIC_OPENAI,
|
||||
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
|
||||
BACKEND_TYPE_OLLAMA,
|
||||
)
|
||||
|
||||
# make type checking work for llama-cpp-python without importing it directly at runtime
|
||||
@@ -149,7 +143,7 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
|
||||
hass.data[DOMAIN][entry.entry_id] = entry
|
||||
|
||||
# call update handler
|
||||
agent: LocalLLMAgent = ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
|
||||
agent: LocalLLMAgent = entry.runtime_data
|
||||
await hass.async_add_executor_job(agent._update_options)
|
||||
|
||||
return True
|
||||
@@ -157,33 +151,11 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback) -> bool:
|
||||
"""Set up Local LLM Conversation from a config entry."""
|
||||
|
||||
def create_agent(backend_type):
|
||||
agent_cls = None
|
||||
|
||||
if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
|
||||
agent_cls = LlamaCppAgent
|
||||
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
|
||||
agent_cls = GenericOpenAIAPIAgent
|
||||
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
|
||||
agent_cls = TextGenerationWebuiAgent
|
||||
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
|
||||
agent_cls = LlamaCppPythonAPIAgent
|
||||
elif backend_type == BACKEND_TYPE_OLLAMA:
|
||||
agent_cls = OllamaAPIAgent
|
||||
|
||||
return agent_cls(hass, entry)
|
||||
|
||||
# create the agent in an executor job because the constructor calls `open()`
|
||||
backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
|
||||
agent = await hass.async_add_executor_job(create_agent, backend_type)
|
||||
|
||||
# call load model
|
||||
await agent._async_load_model(entry)
|
||||
|
||||
# handle updates to the options
|
||||
entry.async_on_unload(entry.add_update_listener(update_listener))
|
||||
|
||||
async_add_entities([agent])
|
||||
# register the agent entity
|
||||
async_add_entities([entry.runtime_data])
|
||||
|
||||
return True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user