"""Config flow for Local LLM Conversation integration."""
from __future__ import annotations
from asyncio import Task
import logging
import os
from typing import Any
import voluptuous as vol
from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime
from homeassistant.components import conversation, ai_task
from homeassistant.data_entry_flow import (
AbortFlow,
)
from homeassistant.config_entries import (
ConfigEntry,
ConfigFlow as BaseConfigFlow,
ConfigFlowResult,
ConfigSubentryFlow,
SubentryFlowResult,
ConfigEntriesFlowManager,
OptionsFlow as BaseOptionsFlow,
ConfigEntryState,
)
from homeassistant.helpers import llm
from homeassistant.helpers.selector import (
NumberSelector,
NumberSelectorConfig,
NumberSelectorMode,
TemplateSelector,
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
TextSelector,
TextSelectorConfig,
TextSelectorType,
BooleanSelector,
BooleanSelectorConfig,
)
from .utils import download_model_from_hf, get_llama_cpp_python_version, install_llama_cpp_python, \
is_valid_hostname, get_available_llama_cpp_versions, MissingQuantizationException
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
DEFAULT_AI_TASK_PROMPT,
CONF_AI_TASK_RETRIES,
DEFAULT_AI_TASK_RETRIES,
CONF_AI_TASK_EXTRACTION_METHOD,
DEFAULT_AI_TASK_EXTRACTION_METHOD,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_BACKEND_TYPE,
CONF_INSTALLED_LLAMACPP_VERSION,
CONF_SELECTED_LANGUAGE,
CONF_SELECTED_LANGUAGE_OPTIONS,
CONF_DOWNLOADED_MODEL_FILE,
CONF_DOWNLOADED_MODEL_QUANTIZATION,
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS,
CONF_THINKING_PREFIX,
CONF_THINKING_SUFFIX,
CONF_TOOL_CALL_PREFIX,
CONF_TOOL_CALL_SUFFIX,
CONF_ENABLE_LEGACY_TOOL_CALLING,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
CONF_MAX_TOOL_CALL_ITERATIONS,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_GENERIC_OPENAI_PATH,
CONF_CONTEXT_LENGTH,
CONF_LLAMACPP_BATCH_SIZE,
CONF_LLAMACPP_THREAD_COUNT,
CONF_LLAMACPP_BATCH_THREAD_COUNT,
CONF_LLAMACPP_REINSTALL,
DEFAULT_CHAT_MODEL,
DEFAULT_PORT,
DEFAULT_SSL,
DEFAULT_MAX_TOKENS,
PERSONA_PROMPTS,
CURRENT_DATE_PROMPT,
DEVICES_PROMPT,
SERVICES_PROMPT,
TOOLS_PROMPT,
AREA_PROMPT,
USER_INSTRUCTION,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_BACKEND_TYPE,
DEFAULT_DOWNLOADED_MODEL_QUANTIZATION,
DEFAULT_THINKING_PREFIX,
DEFAULT_THINKING_SUFFIX,
DEFAULT_TOOL_CALL_PREFIX,
DEFAULT_TOOL_CALL_SUFFIX,
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_MAX_TOOL_CALL_ITERATIONS,
DEFAULT_PROMPT_CACHING_ENABLED,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_GENERIC_OPENAI_PATH,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_LLAMACPP_BATCH_SIZE,
DEFAULT_LLAMACPP_THREAD_COUNT,
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
BACKEND_TYPE_LLAMA_CPP,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_GENERIC_OPENAI_RESPONSES,
BACKEND_TYPE_LLAMA_CPP_SERVER,
BACKEND_TYPE_OLLAMA,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
DOMAIN,
HOME_LLM_API_ID,
DEFAULT_OPTIONS,
option_overrides,
RECOMMENDED_CHAT_MODELS,
EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
)
from . import HomeLLMAPI, LocalLLMConfigEntry, LocalLLMClient, BACKEND_TO_CLS
_LOGGER = logging.getLogger(__name__)
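# Helper that defensively converts a form value to int, falling back to the
# given default when the value is missing or not numeric.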
def _coerce_int(val, default=0):
try:
return int(val)
except (TypeError, ValueError):
return default
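# Schema for the first step of the config flow: choosing which backend to use
# and which language the default prompts should be generated in.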
def pick_backend_schema(backend_type=None, selected_language=None):
return vol.Schema(
{
vol.Required(
CONF_BACKEND_TYPE,
default=backend_type if backend_type else DEFAULT_BACKEND_TYPE
): SelectSelector(SelectSelectorConfig(
options=[
BACKEND_TYPE_LLAMA_CPP,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_GENERIC_OPENAI_RESPONSES,
BACKEND_TYPE_LLAMA_CPP_SERVER,
BACKEND_TYPE_OLLAMA
],
translation_key=CONF_BACKEND_TYPE,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
)),
vol.Required(CONF_SELECTED_LANGUAGE, default=selected_language if selected_language else "en"): SelectSelector(SelectSelectorConfig(
options=CONF_SELECTED_LANGUAGE_OPTIONS,
translation_key=CONF_SELECTED_LANGUAGE,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
)),
}
)
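# Schema for configuring the connection to a remote backend. The default port
# and API path vary by backend (llama.cpp server and Ollama expose different
# default ports), and text-generation-webui additionally accepts an admin key.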
def remote_connection_schema(backend_type: str, *, host=None, port=None, ssl=None, selected_path=None):
extra = {}
default_port = DEFAULT_PORT
default_path = DEFAULT_GENERIC_OPENAI_PATH
if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
extra[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type=TextSelectorType.PASSWORD))
elif backend_type == BACKEND_TYPE_LLAMA_CPP_SERVER:
default_port = "8000"
elif backend_type == BACKEND_TYPE_OLLAMA:
default_port = "11434"
default_path = ""
elif backend_type in [BACKEND_TYPE_GENERIC_OPENAI, BACKEND_TYPE_GENERIC_OPENAI_RESPONSES]:
default_port = ""
return vol.Schema(
{
vol.Required(CONF_HOST, default=host if host else ""): str,
vol.Optional(CONF_PORT, default=port if port else default_port): str,
vol.Required(CONF_SSL, default=ssl if ssl else DEFAULT_SSL): bool,
vol.Optional(CONF_OPENAI_API_KEY): TextSelector(TextSelectorConfig(type=TextSelectorType.PASSWORD)),
vol.Optional(
CONF_GENERIC_OPENAI_PATH,
default=selected_path if selected_path else default_path
): TextSelector(TextSelectorConfig(prefix="/")),
**extra
}
)
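# The main config flow creates one config entry per backend "provider". All
# prompts are shown under the single "user" step; the actual sub-step is
# tracked in self.internal_step so that background tasks (wheel installation)
# can report progress back into the same flow.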
class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
"""Handle a config flow for Local LLM Conversation."""
VERSION = 3
MINOR_VERSION = 2
install_wheel_task = None
install_wheel_error = None
client_config: dict[str, Any]
internal_step: str = "init"
@property
def flow_manager(self) -> ConfigEntriesFlowManager:
"""Return the correct flow manager."""
return self.hass.config_entries.flow
def async_remove(self) -> None:
if self.install_wheel_task:
self.install_wheel_task.cancel()
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the initial step."""
errors = {}
description_placeholders = {}
if self.internal_step == "init":
self.client_config = {}
# make sure the API is registered
if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(self.hass)]):
llm.async_register_api(self.hass, HomeLLMAPI(self.hass))
self.internal_step = "pick_backend"
return self.async_show_form(
step_id="user", data_schema=pick_backend_schema(), last_step=False
)
elif self.internal_step == "pick_backend":
if user_input:
backend = user_input[CONF_BACKEND_TYPE]
self.client_config.update(user_input)
if backend == BACKEND_TYPE_LLAMA_CPP:
installed_version = await self.hass.async_add_executor_job(get_llama_cpp_python_version)
_LOGGER.debug(f"installed version: {installed_version}")
if installed_version and installed_version == EMBEDDED_LLAMA_CPP_PYTHON_VERSION:
self.client_config[CONF_INSTALLED_LLAMACPP_VERSION] = installed_version
return await self.async_step_finish()
else:
self.internal_step = "install_local_wheels"
_LOGGER.debug("Queuing install task...")
async def install_task():
return await self.hass.async_add_executor_job(
install_llama_cpp_python, self.hass.config.config_dir
)
self.install_wheel_task = self.hass.async_create_background_task(
install_task(), name="llama_cpp_python_installation")
return self.async_show_progress(
progress_task=self.install_wheel_task,
step_id="user",
progress_action="install_local_wheels",
)
else:
self.internal_step = "configure_connection"
return self.async_show_form(
step_id="user", data_schema=remote_connection_schema(self.client_config[CONF_BACKEND_TYPE]), last_step=True
)
elif self.install_wheel_error:
errors["base"] = str(self.install_wheel_error)
self.install_wheel_error = None
return self.async_show_form(
step_id="user", data_schema=pick_backend_schema(
backend_type=self.client_config.get(CONF_BACKEND_TYPE),
selected_language=self.client_config.get(CONF_SELECTED_LANGUAGE)
), errors=errors, last_step=False)
elif self.internal_step == "install_local_wheels":
if not self.install_wheel_task:
return self.async_abort(reason="unknown")
if not self.install_wheel_task.done():
return self.async_show_progress(
progress_task=self.install_wheel_task,
step_id="user",
progress_action="install_local_wheels",
)
install_exception = self.install_wheel_task.exception()
if install_exception:
_LOGGER.warning("Failed to install wheel: %s", repr(install_exception))
self.install_wheel_error = "pip_wheel_error"
next_step = "pick_backend"
else:
wheel_install_result = self.install_wheel_task.result()
if not wheel_install_result:
self.install_wheel_error = "pip_wheel_error"
next_step = "pick_backend"
else:
_LOGGER.debug(f"Finished install: {wheel_install_result}")
next_step = "finish"
self.client_config[CONF_INSTALLED_LLAMACPP_VERSION] = await self.hass.async_add_executor_job(get_llama_cpp_python_version)
self.install_wheel_task = None
self.internal_step = next_step
return self.async_show_progress_done(next_step_id="finish")
elif self.internal_step == "configure_connection":
if user_input:
self.client_config.update(user_input)
hostname = user_input.get(CONF_HOST, "")
if not is_valid_hostname(hostname):
errors["base"] = "invalid_hostname"
else:
# validate remote connections
connect_err = await BACKEND_TO_CLS[self.client_config[CONF_BACKEND_TYPE]].async_validate_connection(self.hass, self.client_config)
if connect_err:
errors["base"] = "failed_to_connect"
description_placeholders["exception"] = str(connect_err)
else:
return await self.async_step_finish()
return self.async_show_form(
step_id="user",
data_schema=remote_connection_schema(
self.client_config[CONF_BACKEND_TYPE],
host=self.client_config.get(CONF_HOST),
port=self.client_config.get(CONF_PORT),
ssl=self.client_config.get(CONF_SSL),
selected_path=self.client_config.get(CONF_GENERIC_OPENAI_PATH)
),
errors=errors,
description_placeholders=description_placeholders,
last_step=True
)
else:
raise AbortFlow("Unknown internal step")
async def async_step_finish(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
backend = self.client_config[CONF_BACKEND_TYPE]
title = BACKEND_TO_CLS[backend].get_name(self.client_config)
_LOGGER.debug(f"creating provider with config: {self.client_config}")
# only one llama.cpp provider may exist at a time; remote backends are not restricted
for entry in self.hass.config_entries.async_entries(DOMAIN):
if backend == BACKEND_TYPE_LLAMA_CPP and \
entry.data.get(CONF_BACKEND_TYPE) == BACKEND_TYPE_LLAMA_CPP:
return self.async_abort(reason="duplicate_client")
return self.async_create_entry(
title=title,
description="A Large Language Model Chat Agent",
data={CONF_BACKEND_TYPE: backend},
options=self.client_config,
)
@classmethod
def async_supports_options_flow(cls, config_entry: ConfigEntry) -> bool:
return True
@staticmethod
def async_get_options_flow(
config_entry: ConfigEntry,
) -> BaseOptionsFlow:
"""Create the options flow."""
return OptionsFlow()
@classmethod
def async_get_supported_subentry_types(
cls, config_entry: ConfigEntry
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration."""
return {
conversation.DOMAIN: LocalLLMSubentryFlowHandler,
ai_task.DOMAIN: LocalLLMSubentryFlowHandler,
}
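# The options flow either reinstalls/updates the bundled llama-cpp-python wheel
# (for the local llama.cpp backend) or edits and re-validates the remote
# connection settings for all other backends.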
class OptionsFlow(BaseOptionsFlow):
"""Local LLM config flow options handler."""
model_config: dict[str, Any] | None = None
reinstall_task: Task[Any] | None = None
wheel_install_error: str | None = None
wheel_install_successful: bool = False
async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Manage the options."""
errors = {}
description_placeholders = {}
backend_type = self.config_entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
client_config = dict(self.config_entry.options)
if self.wheel_install_error:
_LOGGER.warning("Failed to install wheel: %s", repr(self.wheel_install_error))
return self.async_abort(reason="pip_wheel_error")
if self.wheel_install_successful:
client_config[CONF_INSTALLED_LLAMACPP_VERSION] = await self.hass.async_add_executor_job(get_llama_cpp_python_version)
_LOGGER.debug(f"new version is: {client_config[CONF_INSTALLED_LLAMACPP_VERSION]}")
return self.async_create_entry(data=client_config)
if backend_type == BACKEND_TYPE_LLAMA_CPP:
potential_versions = await get_available_llama_cpp_versions(self.hass)
schema = vol.Schema({
vol.Required(CONF_LLAMACPP_REINSTALL, default=False): BooleanSelector(BooleanSelectorConfig()),
vol.Required(CONF_INSTALLED_LLAMACPP_VERSION, default=client_config.get(CONF_INSTALLED_LLAMACPP_VERSION, "not installed")): SelectSelector(
SelectSelectorConfig(
options=[ SelectOptionDict(value=x[0], label=x[0] if not x[1] else f"{x[0]} (local)") for x in potential_versions ],
mode=SelectSelectorMode.DROPDOWN,
)
)
})
return self.async_show_form(
step_id="reinstall",
data_schema=schema,
)
else:
if user_input is not None:
client_config.update(user_input)
# validate remote connections
connect_err = await BACKEND_TO_CLS[backend_type].async_validate_connection(self.hass, client_config)
if not connect_err:
return self.async_create_entry(data=client_config)
else:
errors["base"] = "failed_to_connect"
description_placeholders["exception"] = str(connect_err)
schema = remote_connection_schema(
backend_type=backend_type,
host=client_config.get(CONF_HOST),
port=client_config.get(CONF_PORT),
ssl=client_config.get(CONF_SSL),
selected_path=client_config.get(CONF_GENERIC_OPENAI_PATH)
)
return self.async_show_form(
step_id="init",
data_schema=schema,
errors=errors,
description_placeholders=description_placeholders,
)
async def async_step_reinstall(self, user_input: dict[str, Any] | None = None) -> ConfigFlowResult:
client_config = dict(self.config_entry.options)
if user_input is not None:
if not user_input[CONF_LLAMACPP_REINSTALL]:
_LOGGER.debug("Reinstall was not selected, finishing")
return self.async_create_entry(data=client_config)
if not self.reinstall_task:
if not user_input:
return self.async_abort(reason="unknown")
desired_version = user_input.get(CONF_INSTALLED_LLAMACPP_VERSION)
async def install_task():
return await self.hass.async_add_executor_job(
install_llama_cpp_python, self.hass.config.config_dir, True, desired_version
)
self.reinstall_task = self.hass.async_create_background_task(
install_task(), name="llama_cpp_python_installation")
_LOGGER.debug("Queuing reinstall task...")
return self.async_show_progress(
progress_task=self.reinstall_task,
step_id="reinstall",
progress_action="install_local_wheels",
)
if not self.reinstall_task.done():
return self.async_show_progress(
progress_task=self.reinstall_task,
step_id="reinstall",
progress_action="install_local_wheels",
)
_LOGGER.debug("done... checking result")
install_exception = self.reinstall_task.exception()
if install_exception:
self.wheel_install_error = repr(install_exception)
_LOGGER.debug(f"Hit error: {self.wheel_install_error}")
return self.async_show_progress_done(next_step_id="init")
else:
wheel_install_result = self.reinstall_task.result()
if not wheel_install_result:
self.wheel_install_error = "Pip returned false"
_LOGGER.debug(f"Hit error: {self.wheel_install_error} ({wheel_install_result})")
return self.async_show_progress_done(next_step_id="init")
else:
_LOGGER.debug(f"Finished install: {wheel_install_result}")
self.wheel_install_successful = True
return self.async_show_progress_done(next_step_id="init")
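# Model selection schemas: the local variant offers a HuggingFace model name,
# quantization, and an optional path to an already-downloaded GGUF file; the
# remote variant lists the models reported by the backend.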
def STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(model_file=None, chat_model=None, downloaded_model_quantization=None, available_quantizations=None):
return vol.Schema(
{
vol.Optional(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): SelectSelector(SelectSelectorConfig(
options=RECOMMENDED_CHAT_MODELS,
custom_value=True,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
)),
vol.Optional(CONF_DOWNLOADED_MODEL_QUANTIZATION, default=downloaded_model_quantization if downloaded_model_quantization else DEFAULT_DOWNLOADED_MODEL_QUANTIZATION): vol.In(available_quantizations if available_quantizations else CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS),
vol.Optional(CONF_DOWNLOADED_MODEL_FILE, default=model_file if model_file else ""): str,
}
)
def STEP_REMOTE_MODEL_SELECTION_DATA_SCHEMA(available_models: list[str], chat_model: str | None = None):
_LOGGER.debug(f"available models: {available_models}")
return vol.Schema(
{
vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else available_models[0]): SelectSelector(SelectSelectorConfig(
options=available_models,
custom_value=True,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
)),
}
)
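# Fill the placeholder tags in a prompt template (<persona>, <devices>, etc.)
# with the localized strings for the selected language, falling back to
# English when a translation is missing.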
def build_prompt_template(selected_language: str, prompt_template_template: str):
persona = PERSONA_PROMPTS.get(selected_language, PERSONA_PROMPTS["en"])
current_date = CURRENT_DATE_PROMPT.get(selected_language, CURRENT_DATE_PROMPT["en"])
devices = DEVICES_PROMPT.get(selected_language, DEVICES_PROMPT["en"])
services = SERVICES_PROMPT.get(selected_language, SERVICES_PROMPT["en"])
tools = TOOLS_PROMPT.get(selected_language, TOOLS_PROMPT["en"])
area = AREA_PROMPT.get(selected_language, AREA_PROMPT["en"])
user_instruction = USER_INSTRUCTION.get(selected_language, USER_INSTRUCTION["en"])
prompt_template_template = prompt_template_template.replace("<persona>", persona)
prompt_template_template = prompt_template_template.replace("<current_date>", current_date)
prompt_template_template = prompt_template_template.replace("<devices>", devices)
prompt_template_template = prompt_template_template.replace("<services>", services)
prompt_template_template = prompt_template_template.replace("<tools>", tools)
prompt_template_template = prompt_template_template.replace("<area>", area)
prompt_template_template = prompt_template_template.replace("<user_instruction>", user_instruction)
return prompt_template_template
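# Build the model parameter schema shown in the subentry flow. The base
# sampling and tool-calling options are extended with AI-task or conversation
# specific fields, then with backend-specific fields, and finally sorted into
# a stable display order.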
def local_llama_config_option_schema(
hass: HomeAssistant,
language: str,
options: dict[str, Any],
backend_type: str,
subentry_type: str,
) -> dict:
result: dict = {
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)},
default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_THINKING_PREFIX,
description={"suggested_value": options.get(CONF_THINKING_PREFIX)},
default=DEFAULT_THINKING_PREFIX,
): str,
vol.Required(
CONF_THINKING_SUFFIX,
description={"suggested_value": options.get(CONF_THINKING_SUFFIX)},
default=DEFAULT_THINKING_SUFFIX,
): str,
vol.Required(
CONF_TOOL_CALL_PREFIX,
description={"suggested_value": options.get(CONF_TOOL_CALL_PREFIX)},
default=DEFAULT_TOOL_CALL_PREFIX,
): str,
vol.Required(
CONF_TOOL_CALL_SUFFIX,
description={"suggested_value": options.get(CONF_TOOL_CALL_SUFFIX)},
default=DEFAULT_TOOL_CALL_SUFFIX,
): str,
vol.Required(
CONF_ENABLE_LEGACY_TOOL_CALLING,
description={"suggested_value": options.get(CONF_ENABLE_LEGACY_TOOL_CALLING)},
default=DEFAULT_ENABLE_LEGACY_TOOL_CALLING
): bool,
}
if subentry_type == ai_task.DOMAIN:
result.update({
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT)},
default=options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT),
): TemplateSelector(),
vol.Required(
CONF_AI_TASK_EXTRACTION_METHOD,
description={"suggested_value": options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)},
default=options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD),
): SelectSelector(SelectSelectorConfig(
options=[
SelectOptionDict(value="none", label="None"),
SelectOptionDict(value="structure", label="Structured output"),
SelectOptionDict(value="tool", label="Tool call"),
],
mode=SelectSelectorMode.DROPDOWN,
)),
vol.Required(
CONF_AI_TASK_RETRIES,
description={"suggested_value": options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES)},
default=options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES),
): NumberSelector(NumberSelectorConfig(min=0, max=5, step=1, mode=NumberSelectorMode.BOX)),
})
elif subentry_type == conversation.DOMAIN:
default_prompt = build_prompt_template(language, DEFAULT_PROMPT)
apis: list[SelectOptionDict] = [
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
]
result.update({
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, default_prompt)},
default=options.get(CONF_PROMPT, default_prompt),
): TemplateSelector(),
vol.Required(
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_IN_CONTEXT_EXAMPLES_FILE,
description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
): str,
vol.Required(
CONF_NUM_IN_CONTEXT_EXAMPLES,
description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
vol.Required(
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
): TextSelector(TextSelectorConfig(multiple=True)),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default=None,
): SelectSelector(SelectSelectorConfig(options=apis, multiple=True)),
vol.Optional(
CONF_REFRESH_SYSTEM_PROMPT,
description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)},
default=options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT),
): BooleanSelector(BooleanSelectorConfig()),
vol.Optional(
CONF_REMEMBER_CONVERSATION,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)},
default=options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION),
): BooleanSelector(BooleanSelectorConfig()),
vol.Optional(
CONF_REMEMBER_NUM_INTERACTIONS,
description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)},
default=options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS),
): NumberSelector(NumberSelectorConfig(min=0, max=100, mode=NumberSelectorMode.BOX)),
vol.Optional(
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION)},
default=options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION),
): NumberSelector(NumberSelectorConfig(min=0, max=1440, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_MAX_TOOL_CALL_ITERATIONS,
description={"suggested_value": options.get(CONF_MAX_TOOL_CALL_ITERATIONS)},
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
): int,
})
if backend_type == BACKEND_TYPE_LLAMA_CPP:
if subentry_type == conversation.DOMAIN:
result.update({
vol.Required(
CONF_PROMPT_CACHING_ENABLED,
description={"suggested_value": options.get(CONF_PROMPT_CACHING_ENABLED)},
default=DEFAULT_PROMPT_CACHING_ENABLED,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_PROMPT_CACHING_INTERVAL,
description={"suggested_value": options.get(CONF_PROMPT_CACHING_INTERVAL)},
default=DEFAULT_PROMPT_CACHING_INTERVAL,
): NumberSelector(NumberSelectorConfig(min=1, max=60, step=1)),
})
result.update({
vol.Required(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS,
): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
vol.Required(
CONF_TOP_K,
description={"suggested_value": options.get(CONF_TOP_K)},
default=DEFAULT_TOP_K,
): NumberSelector(NumberSelectorConfig(min=1, max=256, step=1)),
vol.Required(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_MIN_P,
description={"suggested_value": options.get(CONF_MIN_P)},
default=DEFAULT_MIN_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_TYPICAL_P,
description={"suggested_value": options.get(CONF_TYPICAL_P)},
default=DEFAULT_TYPICAL_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
# TODO: add rope_scaling_type
vol.Required(
CONF_CONTEXT_LENGTH,
description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
default=DEFAULT_CONTEXT_LENGTH,
): NumberSelector(NumberSelectorConfig(min=512, max=1_048_576, step=512)),
vol.Required(
CONF_LLAMACPP_BATCH_SIZE,
description={"suggested_value": options.get(CONF_LLAMACPP_BATCH_SIZE)},
default=DEFAULT_LLAMACPP_BATCH_SIZE,
): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
vol.Required(
CONF_LLAMACPP_THREAD_COUNT,
description={"suggested_value": options.get(CONF_LLAMACPP_THREAD_COUNT)},
default=DEFAULT_LLAMACPP_THREAD_COUNT,
): NumberSelector(NumberSelectorConfig(min=1, max=((os.cpu_count() or 1) * 2), step=1)),
vol.Required(
CONF_LLAMACPP_BATCH_THREAD_COUNT,
description={"suggested_value": options.get(CONF_LLAMACPP_BATCH_THREAD_COUNT)},
default=DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
): NumberSelector(NumberSelectorConfig(min=1, max=((os.cpu_count() or 1) * 2), step=1)),
vol.Required(
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
description={"suggested_value": options.get(CONF_LLAMACPP_ENABLE_FLASH_ATTENTION)},
default=DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_USE_GBNF_GRAMMAR,
description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
default=DEFAULT_USE_GBNF_GRAMMAR,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_GBNF_GRAMMAR_FILE,
description={"suggested_value": options.get(CONF_GBNF_GRAMMAR_FILE)},
default=DEFAULT_GBNF_GRAMMAR_FILE,
): str
})
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
result.update({
vol.Required(
CONF_CONTEXT_LENGTH,
description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
default=DEFAULT_CONTEXT_LENGTH,
): NumberSelector(NumberSelectorConfig(min=512, max=1_048_576, step=512)),
vol.Required(
CONF_TOP_K,
description={"suggested_value": options.get(CONF_TOP_K)},
default=DEFAULT_TOP_K,
): NumberSelector(NumberSelectorConfig(min=1, max=256, step=1)),
vol.Required(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_MIN_P,
description={"suggested_value": options.get(CONF_MIN_P)},
default=DEFAULT_MIN_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_TYPICAL_P,
description={"suggested_value": options.get(CONF_TYPICAL_P)},
default=DEFAULT_TYPICAL_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_REQUEST_TIMEOUT,
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
vol.Optional(
CONF_TEXT_GEN_WEBUI_PRESET,
description={"suggested_value": options.get(CONF_TEXT_GEN_WEBUI_PRESET)},
): str,
vol.Required(
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
description={"suggested_value": options.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE)},
default=DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
): SelectSelector(SelectSelectorConfig(
options=[TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
translation_key=CONF_TEXT_GEN_WEBUI_CHAT_MODE,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
)),
})
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
result.update({
vol.Required(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_REQUEST_TIMEOUT,
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
})
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI_RESPONSES:
del result[CONF_REMEMBER_NUM_INTERACTIONS]
result.update({
vol.Required(
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES)},
default=options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION),
): NumberSelector(NumberSelectorConfig(min=0, max=180, step=0.5, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=3, step=0.05)),
vol.Required(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_REQUEST_TIMEOUT,
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
})
elif backend_type == BACKEND_TYPE_LLAMA_CPP_SERVER:
result.update({
vol.Required(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS,
): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
vol.Required(
CONF_TOP_K,
description={"suggested_value": options.get(CONF_TOP_K)},
default=DEFAULT_TOP_K,
): NumberSelector(NumberSelectorConfig(min=1, max=256, step=1)),
vol.Required(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_USE_GBNF_GRAMMAR,
description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
default=DEFAULT_USE_GBNF_GRAMMAR,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_GBNF_GRAMMAR_FILE,
description={"suggested_value": options.get(CONF_GBNF_GRAMMAR_FILE)},
default=DEFAULT_GBNF_GRAMMAR_FILE,
): str,
vol.Required(
CONF_REQUEST_TIMEOUT,
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
})
elif backend_type == BACKEND_TYPE_OLLAMA:
result.update({
vol.Required(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS,
): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
vol.Required(
CONF_CONTEXT_LENGTH,
description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
default=DEFAULT_CONTEXT_LENGTH,
): NumberSelector(NumberSelectorConfig(min=512, max=1_048_576, step=512)),
vol.Required(
CONF_TOP_K,
description={"suggested_value": options.get(CONF_TOP_K)},
default=DEFAULT_TOP_K,
): NumberSelector(NumberSelectorConfig(min=1, max=256, step=1)),
vol.Required(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_TYPICAL_P,
description={"suggested_value": options.get(CONF_TYPICAL_P)},
default=DEFAULT_TYPICAL_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_OLLAMA_JSON_MODE,
description={"suggested_value": options.get(CONF_OLLAMA_JSON_MODE)},
default=DEFAULT_OLLAMA_JSON_MODE,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_REQUEST_TIMEOUT,
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_OLLAMA_KEEP_ALIVE_MIN,
description={"suggested_value": options.get(CONF_OLLAMA_KEEP_ALIVE_MIN)},
default=DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
): NumberSelector(NumberSelectorConfig(min=-1, max=1440, step=1, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)),
})
# sort the options
global_order = [
# general
CONF_LLM_HASS_API,
CONF_PROMPT,
CONF_AI_TASK_EXTRACTION_METHOD,
CONF_AI_TASK_RETRIES,
CONF_CONTEXT_LENGTH,
CONF_MAX_TOKENS,
# sampling parameters
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_MIN_P,
CONF_TYPICAL_P,
CONF_TOP_K,
# tool calling/reasoning
CONF_THINKING_PREFIX,
CONF_THINKING_SUFFIX,
CONF_TOOL_CALL_PREFIX,
CONF_TOOL_CALL_SUFFIX,
CONF_MAX_TOOL_CALL_ITERATIONS,
CONF_ENABLE_LEGACY_TOOL_CALLING,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
# integration specific options
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_IN_CONTEXT_EXAMPLES_FILE,
CONF_NUM_IN_CONTEXT_EXAMPLES,
# backend specific options
CONF_LLAMACPP_BATCH_SIZE,
CONF_LLAMACPP_THREAD_COUNT,
CONF_LLAMACPP_BATCH_THREAD_COUNT,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
]
result = { k: v for k, v in sorted(result.items(), key=lambda item: global_order.index(item[0]) if item[0] in global_order else 9999) }
return result
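# Subentry flow for adding or reconfiguring a conversation agent or AI task
# backed by an existing provider entry. Steps: pick_model -> (optional)
# download -> model_parameters -> finish.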
class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing Local LLM subentries."""
def __init__(self) -> None:
"""Initialize the subentry flow."""
super().__init__()
# state for subentry flow
self.model_config: dict[str, Any] = {}
self.download_task = None
self.download_error = None
@property
def _is_new(self) -> bool:
"""Return if this is a new subentry."""
return self.source == "user"
@property
def _client(self) -> LocalLLMClient:
"""Return the Ollama client."""
entry: LocalLLMConfigEntry = self._get_entry()
return entry.runtime_data
async def async_step_pick_model(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
schema = vol.Schema({})
errors = {}
description_placeholders = {}
entry = self._get_entry()
backend_type = entry.data[CONF_BACKEND_TYPE]
if backend_type == BACKEND_TYPE_LLAMA_CPP:
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA()
else:
schema = STEP_REMOTE_MODEL_SELECTION_DATA_SCHEMA(await entry.runtime_data.async_get_available_models())
if self.download_error:
if isinstance(self.download_error, MissingQuantizationException):
available_quants = list(set(self.download_error.available_quants).intersection(set(CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS)))
if len(available_quants) == 0:
errors["base"] = "no_supported_ggufs"
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
chat_model=self.model_config[CONF_CHAT_MODEL],
downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION],
)
else:
errors["base"] = "missing_quantization"
description_placeholders["missing"] = self.download_error.missing_quant
description_placeholders["available"] = ", ".join(self.download_error.available_quants)
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
chat_model=self.model_config[CONF_CHAT_MODEL],
downloaded_model_quantization=self.download_error.available_quants[0],
available_quantizations=available_quants,
)
else:
errors["base"] = "download_failed"
description_placeholders["exception"] = str(self.download_error)
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
chat_model=self.model_config[CONF_CHAT_MODEL],
downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION],
)
if user_input and "result" not in user_input:
self.model_config.update(user_input)
if backend_type == BACKEND_TYPE_LLAMA_CPP:
model_file = self.model_config.get(CONF_DOWNLOADED_MODEL_FILE, "")
if not model_file:
model_name = self.model_config.get(CONF_CHAT_MODEL)
if model_name:
return await self.async_step_download()
else:
errors["base"] = "no_model_name_or_file"
if os.path.exists(model_file):
self.model_config[CONF_CHAT_MODEL] = os.path.basename(model_file)
return await self.async_step_model_parameters()
else:
errors["base"] = "missing_model_file"
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(model_file)
else:
return await self.async_step_model_parameters()
return self.async_show_form(
step_id="pick_model",
data_schema=schema,
errors=errors,
description_placeholders=description_placeholders,
last_step=False,
)
async def async_step_download(
self, user_input: dict[str, Any] | None = None,
) -> SubentryFlowResult:
if not self.download_task:
model_name = self.model_config[CONF_CHAT_MODEL]
quantization_type = self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION]
storage_folder = os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "models")
async def download_task():
return await self.hass.async_add_executor_job(
download_model_from_hf, model_name, quantization_type, storage_folder
)
self.download_task = self.hass.async_create_background_task(
download_task(), name="model_download_task")
return self.async_show_progress(
progress_task=self.download_task,
step_id="download",
progress_action="download",
)
if self.download_task and not self.download_task.done():
return self.async_show_progress(
progress_task=self.download_task,
step_id="download",
progress_action="download",
)
download_exception = self.download_task.exception()
if download_exception:
_LOGGER.info("Failed to download model: %s", repr(download_exception))
self.download_error = download_exception
self.download_task = None
return self.async_show_progress_done(next_step_id="pick_model")
else:
self.model_config[CONF_DOWNLOADED_MODEL_FILE] = self.download_task.result()
self.download_task = None
return self.async_show_progress_done(next_step_id="model_parameters")
async def async_step_model_parameters(
self, user_input: dict[str, Any] | None = None,
) -> SubentryFlowResult:
errors = {}
description_placeholders = {}
entry = self._get_entry()
backend_type = entry.data[CONF_BACKEND_TYPE]
is_ai_task = self._subentry_type == ai_task.DOMAIN
if is_ai_task:
if CONF_PROMPT not in self.model_config:
self.model_config[CONF_PROMPT] = DEFAULT_AI_TASK_PROMPT
if CONF_AI_TASK_RETRIES not in self.model_config:
self.model_config[CONF_AI_TASK_RETRIES] = DEFAULT_AI_TASK_RETRIES
if CONF_AI_TASK_EXTRACTION_METHOD not in self.model_config:
self.model_config[CONF_AI_TASK_EXTRACTION_METHOD] = DEFAULT_AI_TASK_EXTRACTION_METHOD
elif CONF_PROMPT not in self.model_config:
# determine selected language from model config or parent options
selected_language = self.model_config.get(
CONF_SELECTED_LANGUAGE, entry.options.get(CONF_SELECTED_LANGUAGE, "en")
)
model_name = self.model_config.get(CONF_CHAT_MODEL, "").lower()
OPTIONS_OVERRIDES = option_overrides(backend_type)
selected_default_options = {**DEFAULT_OPTIONS}
for key in OPTIONS_OVERRIDES.keys():
if key in model_name:
selected_default_options.update(OPTIONS_OVERRIDES[key])
break
# Build prompt template using the selected language
selected_default_options[CONF_PROMPT] = build_prompt_template(
selected_language, str(selected_default_options.get(CONF_PROMPT, DEFAULT_PROMPT))
)
self.model_config = {**selected_default_options, **self.model_config}
schema = vol.Schema(
local_llama_config_option_schema(
self.hass,
entry.options.get(CONF_SELECTED_LANGUAGE, "en"),
self.model_config,
backend_type,
self._subentry_type,
)
)
if user_input:
if not is_ai_task:
if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
errors["base"] = "sys_refresh_caching_enabled"
if user_input.get(CONF_USE_GBNF_GRAMMAR):
filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_gbnf_file"
description_placeholders["filename"] = filename
if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_icl_file"
description_placeholders["filename"] = filename
# --- Normalize numeric fields to ints to avoid slice/type errors later ---
for key in (
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_MAX_TOOL_CALL_ITERATIONS,
CONF_CONTEXT_LENGTH,
CONF_MAX_TOKENS,
CONF_REQUEST_TIMEOUT,
CONF_AI_TASK_RETRIES,
):
if key in user_input:
user_input[key] = _coerce_int(user_input[key], user_input.get(key) or 0)
if len(errors) == 0:
try:
# validate input
schema(user_input)
self.model_config.update(user_input)
return await self.async_step_finish()
except Exception:
_LOGGER.exception("An unknown error has occurred!")
errors["base"] = "unknown"
return self.async_show_form(
step_id="model_parameters",
data_schema=schema,
errors=errors,
description_placeholders=description_placeholders,
)
async def async_step_finish(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Step after model downloading has succeeded."""
# Model download completed, create/update the entry with stored config
if self._is_new:
return self.async_create_entry(
title=self.model_config.get(CONF_CHAT_MODEL, "Model"),
data=self.model_config,
)
else:
return self.async_update_and_abort(
self._get_entry(), self._get_reconfigure_subentry(), data=self.model_config
)
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Handle model selection and configuration step."""
# Ensure the parent entry is loaded before allowing subentry edits
if self._get_entry().state != ConfigEntryState.LOADED:
return self.async_abort(reason="entry_not_loaded")
if not self.model_config:
self.model_config = {}
return await self.async_step_pick_model(user_input)
async_step_init = async_step_user
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None):
if not self.model_config:
self.model_config = dict(self._get_reconfigure_subentry().data)
return await self.async_step_model_parameters(user_input)