mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
1278 lines
55 KiB
Python
"""Config flow for Local LLM Conversation integration."""
|
|
from __future__ import annotations
|
|
|
|
from asyncio import Task
|
|
import logging
|
|
import os
|
|
from typing import Any
|
|
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime
|
|
from homeassistant.components import conversation, ai_task
|
|
from homeassistant.data_entry_flow import (
|
|
AbortFlow,
|
|
)
|
|
from homeassistant.config_entries import (
|
|
ConfigEntry,
|
|
ConfigFlow as BaseConfigFlow,
|
|
ConfigFlowResult,
|
|
ConfigSubentryFlow,
|
|
SubentryFlowResult,
|
|
ConfigEntriesFlowManager,
|
|
OptionsFlow as BaseOptionsFlow,
|
|
ConfigEntryState,
|
|
)
|
|
from homeassistant.helpers import llm
|
|
from homeassistant.helpers.selector import (
|
|
NumberSelector,
|
|
NumberSelectorConfig,
|
|
NumberSelectorMode,
|
|
TemplateSelector,
|
|
SelectOptionDict,
|
|
SelectSelector,
|
|
SelectSelectorConfig,
|
|
SelectSelectorMode,
|
|
TextSelector,
|
|
TextSelectorConfig,
|
|
TextSelectorType,
|
|
BooleanSelector,
|
|
BooleanSelectorConfig,
|
|
)
|
|
|
|
from .utils import download_model_from_hf, get_llama_cpp_python_version, install_llama_cpp_python, \
|
|
is_valid_hostname, get_available_llama_cpp_versions, MissingQuantizationException
|
|
from .const import (
|
|
CONF_CHAT_MODEL,
|
|
CONF_MAX_TOKENS,
|
|
CONF_PROMPT,
|
|
DEFAULT_AI_TASK_PROMPT,
|
|
CONF_AI_TASK_RETRIES,
|
|
DEFAULT_AI_TASK_RETRIES,
|
|
CONF_AI_TASK_EXTRACTION_METHOD,
|
|
DEFAULT_AI_TASK_EXTRACTION_METHOD,
|
|
CONF_TEMPERATURE,
|
|
CONF_TOP_K,
|
|
CONF_TOP_P,
|
|
CONF_MIN_P,
|
|
CONF_TYPICAL_P,
|
|
CONF_REQUEST_TIMEOUT,
|
|
CONF_BACKEND_TYPE,
|
|
CONF_INSTALLED_LLAMACPP_VERSION,
|
|
CONF_SELECTED_LANGUAGE,
|
|
CONF_SELECTED_LANGUAGE_OPTIONS,
|
|
CONF_DOWNLOADED_MODEL_FILE,
|
|
CONF_DOWNLOADED_MODEL_QUANTIZATION,
|
|
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS,
|
|
CONF_THINKING_PREFIX,
|
|
CONF_THINKING_SUFFIX,
|
|
CONF_TOOL_CALL_PREFIX,
|
|
CONF_TOOL_CALL_SUFFIX,
|
|
CONF_ENABLE_LEGACY_TOOL_CALLING,
|
|
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
|
CONF_USE_GBNF_GRAMMAR,
|
|
CONF_GBNF_GRAMMAR_FILE,
|
|
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
|
CONF_TEXT_GEN_WEBUI_PRESET,
|
|
CONF_REFRESH_SYSTEM_PROMPT,
|
|
CONF_REMEMBER_CONVERSATION,
|
|
CONF_REMEMBER_NUM_INTERACTIONS,
|
|
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
|
|
CONF_MAX_TOOL_CALL_ITERATIONS,
|
|
CONF_PROMPT_CACHING_ENABLED,
|
|
CONF_PROMPT_CACHING_INTERVAL,
|
|
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
|
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
|
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
|
CONF_OPENAI_API_KEY,
|
|
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
|
|
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
|
CONF_OLLAMA_KEEP_ALIVE_MIN,
|
|
CONF_OLLAMA_JSON_MODE,
|
|
CONF_GENERIC_OPENAI_PATH,
|
|
CONF_CONTEXT_LENGTH,
|
|
CONF_LLAMACPP_BATCH_SIZE,
|
|
CONF_LLAMACPP_THREAD_COUNT,
|
|
CONF_LLAMACPP_BATCH_THREAD_COUNT,
|
|
CONF_LLAMACPP_REINSTALL,
|
|
DEFAULT_CHAT_MODEL,
|
|
DEFAULT_PORT,
|
|
DEFAULT_SSL,
|
|
DEFAULT_MAX_TOKENS,
|
|
PERSONA_PROMPTS,
|
|
CURRENT_DATE_PROMPT,
|
|
DEVICES_PROMPT,
|
|
SERVICES_PROMPT,
|
|
TOOLS_PROMPT,
|
|
AREA_PROMPT,
|
|
USER_INSTRUCTION,
|
|
DEFAULT_PROMPT,
|
|
DEFAULT_TEMPERATURE,
|
|
DEFAULT_TOP_K,
|
|
DEFAULT_TOP_P,
|
|
DEFAULT_MIN_P,
|
|
DEFAULT_TYPICAL_P,
|
|
DEFAULT_REQUEST_TIMEOUT,
|
|
DEFAULT_BACKEND_TYPE,
|
|
DEFAULT_DOWNLOADED_MODEL_QUANTIZATION,
|
|
DEFAULT_THINKING_PREFIX,
|
|
DEFAULT_THINKING_SUFFIX,
|
|
DEFAULT_TOOL_CALL_PREFIX,
|
|
DEFAULT_TOOL_CALL_SUFFIX,
|
|
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
|
|
DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
|
DEFAULT_USE_GBNF_GRAMMAR,
|
|
DEFAULT_GBNF_GRAMMAR_FILE,
|
|
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
|
DEFAULT_REFRESH_SYSTEM_PROMPT,
|
|
DEFAULT_REMEMBER_CONVERSATION,
|
|
DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
|
DEFAULT_MAX_TOOL_CALL_ITERATIONS,
|
|
DEFAULT_PROMPT_CACHING_ENABLED,
|
|
DEFAULT_PROMPT_CACHING_INTERVAL,
|
|
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
|
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
|
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
|
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
|
|
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
|
|
DEFAULT_OLLAMA_JSON_MODE,
|
|
DEFAULT_GENERIC_OPENAI_PATH,
|
|
DEFAULT_CONTEXT_LENGTH,
|
|
DEFAULT_LLAMACPP_BATCH_SIZE,
|
|
DEFAULT_LLAMACPP_THREAD_COUNT,
|
|
DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
|
|
BACKEND_TYPE_LLAMA_CPP,
|
|
BACKEND_TYPE_TEXT_GEN_WEBUI,
|
|
BACKEND_TYPE_GENERIC_OPENAI,
|
|
BACKEND_TYPE_GENERIC_OPENAI_RESPONSES,
|
|
BACKEND_TYPE_LLAMA_CPP_SERVER,
|
|
BACKEND_TYPE_OLLAMA,
|
|
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
|
|
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
|
|
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
|
|
DOMAIN,
|
|
HOME_LLM_API_ID,
|
|
DEFAULT_OPTIONS,
|
|
option_overrides,
|
|
RECOMMENDED_CHAT_MODELS,
|
|
EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
|
|
)
|
|
|
|
from . import HomeLLMAPI, LocalLLMConfigEntry, LocalLLMClient, BACKEND_TO_CLS
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
def _coerce_int(val, default=0):
|
|
try:
|
|
return int(val)
|
|
except (TypeError, ValueError):
|
|
return default
|
|
|
|
def pick_backend_schema(backend_type=None, selected_language=None):
    """Build the form schema for the initial backend + language selection step.

    Both arguments act as pre-filled defaults (e.g. when re-showing the form
    after an error); falsy values fall back to the integration defaults.
    """
    backend_selector = SelectSelector(SelectSelectorConfig(
        options=[
            BACKEND_TYPE_LLAMA_CPP,
            BACKEND_TYPE_TEXT_GEN_WEBUI,
            BACKEND_TYPE_GENERIC_OPENAI,
            BACKEND_TYPE_GENERIC_OPENAI_RESPONSES,
            BACKEND_TYPE_LLAMA_CPP_SERVER,
            BACKEND_TYPE_OLLAMA,
        ],
        translation_key=CONF_BACKEND_TYPE,
        multiple=False,
        mode=SelectSelectorMode.DROPDOWN,
    ))
    language_selector = SelectSelector(SelectSelectorConfig(
        options=CONF_SELECTED_LANGUAGE_OPTIONS,
        translation_key=CONF_SELECTED_LANGUAGE,
        multiple=False,
        mode=SelectSelectorMode.DROPDOWN,
    ))

    return vol.Schema({
        vol.Required(CONF_BACKEND_TYPE, default=backend_type or DEFAULT_BACKEND_TYPE): backend_selector,
        vol.Required(CONF_SELECTED_LANGUAGE, default=selected_language or "en"): language_selector,
    })
|
|
|
|
def remote_connection_schema(backend_type: str, *, host=None, port=None, ssl=None, selected_path=None):
    """Build the form schema for configuring a remote backend connection.

    The default port and API path vary per backend; text-generation-webui
    additionally exposes an optional admin key field.  The keyword arguments
    pre-fill the form (e.g. when re-showing it after a failed connection).
    """
    backend_specific_fields = {}
    default_port = DEFAULT_PORT
    default_path = DEFAULT_GENERIC_OPENAI_PATH

    if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
        # text-generation-webui has a separate admin API key for model management
        backend_specific_fields[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(
            TextSelectorConfig(type=TextSelectorType.PASSWORD))
    elif backend_type == BACKEND_TYPE_LLAMA_CPP_SERVER:
        default_port = "8000"
    elif backend_type == BACKEND_TYPE_OLLAMA:
        default_port = "11434"
        default_path = ""
    elif backend_type in [BACKEND_TYPE_GENERIC_OPENAI, BACKEND_TYPE_GENERIC_OPENAI_RESPONSES]:
        # generic OpenAI-compatible endpoints are usually addressed by host alone
        default_port = ""

    fields = {
        vol.Required(CONF_HOST, default=host or ""): str,
        vol.Optional(CONF_PORT, default=port or default_port): str,
        vol.Required(CONF_SSL, default=ssl or DEFAULT_SSL): bool,
        vol.Optional(CONF_OPENAI_API_KEY): TextSelector(TextSelectorConfig(type=TextSelectorType.PASSWORD)),
        vol.Optional(
            CONF_GENERIC_OPENAI_PATH,
            default=selected_path or default_path,
        ): TextSelector(TextSelectorConfig(prefix="/")),
    }
    fields.update(backend_specific_fields)
    return vol.Schema(fields)
|
|
|
|
class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
    """Handle a config flow for Local LLM Conversation."""

    # Config entry schema version; bump when stored data/options change shape.
    VERSION = 3
    MINOR_VERSION = 2

    # Background task installing the llama-cpp-python wheel (llama.cpp backend only).
    install_wheel_task = None
    # Error key from a failed wheel install, surfaced on the next form render.
    install_wheel_error = None
    # Accumulates user selections across the multi-step flow.
    client_config: dict[str, Any]
    # This flow reuses a single HA step ("user") and tracks its own sub-step here.
    internal_step: str = "init"

    @property
    def flow_manager(self) -> ConfigEntriesFlowManager:
        """Return the correct flow manager."""
        return self.hass.config_entries.flow

    def async_remove(self) -> None:
        # Flow was abandoned/removed: don't leave the wheel install running.
        if self.install_wheel_task:
            self.install_wheel_task.cancel()

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Handle the initial step.

        Implements a small state machine via ``self.internal_step``:
        init -> pick_backend -> (install_local_wheels | configure_connection)
        -> finish.  Every form is shown under step_id "user" so HA routes all
        submissions back here.
        """
        errors = {}
        description_placeholders = {}

        if self.internal_step == "init":
            self.client_config = {}

            # make sure the API is registered
            if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(self.hass)]):
                llm.async_register_api(self.hass, HomeLLMAPI(self.hass))

            self.internal_step = "pick_backend"
            return self.async_show_form(
                step_id="user", data_schema=pick_backend_schema(), last_step=False
            )
        elif self.internal_step == "pick_backend":
            if user_input:
                backend = user_input[CONF_BACKEND_TYPE]
                self.client_config.update(user_input)
                if backend == BACKEND_TYPE_LLAMA_CPP:
                    # Local backend: ensure the embedded llama-cpp-python wheel
                    # is installed at the expected version before finishing.
                    installed_version = await self.hass.async_add_executor_job(get_llama_cpp_python_version)
                    _LOGGER.debug(f"installed version: {installed_version}")
                    if installed_version and installed_version == EMBEDDED_LLAMA_CPP_PYTHON_VERSION:
                        self.client_config[CONF_INSTALLED_LLAMACPP_VERSION] = installed_version
                        return await self.async_step_finish()
                    else:
                        # Missing or mismatched version: install in the background
                        # and show a progress screen while pip runs.
                        self.internal_step = "install_local_wheels"
                        _LOGGER.debug("Queuing install task...")
                        async def install_task():
                            # pip install is blocking; run it in the executor.
                            return await self.hass.async_add_executor_job(
                                install_llama_cpp_python, self.hass.config.config_dir
                            )

                        self.install_wheel_task = self.hass.async_create_background_task(
                            install_task(), name="llama_cpp_python_installation")

                        return self.async_show_progress(
                            progress_task=self.install_wheel_task,
                            step_id="user",
                            progress_action="install_local_wheels",
                        )
                else:
                    # Remote backend: collect host/port/auth details next.
                    self.internal_step = "configure_connection"
                    return self.async_show_form(
                        step_id="user", data_schema=remote_connection_schema(self.client_config[CONF_BACKEND_TYPE]), last_step=True
                    )
            elif self.install_wheel_error:
                # A previous wheel install failed; re-show the backend picker
                # with the error attached (one-shot: cleared after display).
                errors["base"] = str(self.install_wheel_error)
                self.install_wheel_error = None

            # No input yet (or error path): re-show the picker with prior values.
            return self.async_show_form(
                step_id="user", data_schema=pick_backend_schema(
                    backend_type=self.client_config.get(CONF_BACKEND_TYPE),
                    selected_language=self.client_config.get(CONF_SELECTED_LANGUAGE)
                ), errors=errors, last_step=False)
        elif self.internal_step == "install_local_wheels":
            if not self.install_wheel_task:
                # Should not happen: progress step without a task to track.
                return self.async_abort(reason="unknown")

            if not self.install_wheel_task.done():
                # Still installing: keep the progress screen up.
                return self.async_show_progress(
                    progress_task=self.install_wheel_task,
                    step_id="user",
                    progress_action="install_local_wheels",
                )

            # Task finished: decide whether to proceed or bounce back to the
            # backend picker with an error.
            install_exception = self.install_wheel_task.exception()
            if install_exception:
                _LOGGER.warning("Failed to install wheel: %s", repr(install_exception))
                self.install_wheel_error = "pip_wheel_error"
                next_step = "pick_backend"
            else:
                wheel_install_result = self.install_wheel_task.result()
                if not wheel_install_result:
                    # pip reported failure without raising.
                    self.install_wheel_error = "pip_wheel_error"
                    next_step = "pick_backend"
                else:
                    _LOGGER.debug(f"Finished install: {wheel_install_result}")
                    next_step = "finish"
                    # Record what actually got installed (queried, not assumed).
                    self.client_config[CONF_INSTALLED_LLAMACPP_VERSION] = await self.hass.async_add_executor_job(get_llama_cpp_python_version)

            self.install_wheel_task = None
            self.internal_step = next_step
            # NOTE(review): next_step_id is always "finish" even when
            # internal_step was set to "pick_backend" — the error path appears
            # to rely on async_step_finish/user re-routing; confirm upstream.
            return self.async_show_progress_done(next_step_id="finish")
        elif self.internal_step == "configure_connection":
            if user_input:
                self.client_config.update(user_input)

                hostname = user_input.get(CONF_HOST, "")
                if not is_valid_hostname(hostname):
                    errors["base"] = "invalid_hostname"
                else:
                    # validate remote connections
                    connect_err = await BACKEND_TO_CLS[self.client_config[CONF_BACKEND_TYPE]].async_validate_connection(self.hass, self.client_config)

                    if connect_err:
                        errors["base"] = "failed_to_connect"
                        description_placeholders["exception"] = str(connect_err)
                    else:
                        return await self.async_step_finish()

            # Validation failed (or no input): re-show the connection form
            # pre-filled with whatever was entered so far.
            return self.async_show_form(
                step_id="user",
                data_schema=remote_connection_schema(
                    self.client_config[CONF_BACKEND_TYPE],
                    host=self.client_config.get(CONF_HOST),
                    port=self.client_config.get(CONF_PORT),
                    ssl=self.client_config.get(CONF_SSL),
                    selected_path=self.client_config.get(CONF_GENERIC_OPENAI_PATH)
                ),
                errors=errors,
                description_placeholders=description_placeholders,
                last_step=True
            )
        else:
            # Unknown internal state — abort rather than loop forever.
            raise AbortFlow("Unknown internal step")

    async def async_step_finish(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        # Create the config entry from the accumulated client_config.

        backend = self.client_config[CONF_BACKEND_TYPE]
        title = BACKEND_TO_CLS[backend].get_name(self.client_config)
        _LOGGER.debug(f"creating provider with config: {self.client_config}")

        # block duplicate providers
        # Only one llama.cpp (in-process) entry may exist at a time.
        for entry in self.hass.config_entries.async_entries(DOMAIN):
            if backend == BACKEND_TYPE_LLAMA_CPP and \
                entry.data.get(CONF_BACKEND_TYPE) == BACKEND_TYPE_LLAMA_CPP:
                return self.async_abort(reason="duplicate_client")

        return self.async_create_entry(
            title=title,
            description="A Large Language Model Chat Agent",
            data={CONF_BACKEND_TYPE: backend},
            options=self.client_config,
        )

    @classmethod
    def async_supports_options_flow(cls, config_entry: ConfigEntry) -> bool:
        # All backends expose an options flow (see OptionsFlow below).
        return True

    @staticmethod
    def async_get_options_flow(
        config_entry: ConfigEntry,
    ) -> BaseOptionsFlow:
        """Create the options flow."""
        return OptionsFlow()

    @classmethod
    def async_get_supported_subentry_types(
        cls, config_entry: ConfigEntry
    ) -> dict[str, type[ConfigSubentryFlow]]:
        """Return subentries supported by this integration."""
        # Both conversation agents and AI tasks are configured via the same
        # subentry flow handler (defined elsewhere in this file).
        return {
            conversation.DOMAIN: LocalLLMSubentryFlowHandler,
            ai_task.DOMAIN: LocalLLMSubentryFlowHandler,
        }
|
|
|
|
|
|
class OptionsFlow(BaseOptionsFlow):
    """Local LLM config flow options handler."""

    # Unused in the visible code — presumably a leftover; confirm before removing.
    model_config: dict[str, Any] | None = None
    # Background task reinstalling llama-cpp-python (llama.cpp backend only).
    reinstall_task: Task[Any] | None = None
    # repr() of the install failure, checked when the flow returns to "init".
    wheel_install_error: str | None = None
    # Set after a successful reinstall so "init" can persist the new version.
    wheel_install_successful: bool = False

    async def async_step_init(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Manage the options.

        For the llama.cpp backend this shows the wheel reinstall form (handled
        by async_step_reinstall); for remote backends it edits and re-validates
        the connection settings.  It is also re-entered after a reinstall
        finishes, via the progress-done transition.
        """
        errors = {}
        description_placeholders = {}

        backend_type = self.config_entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
        client_config = dict(self.config_entry.options)

        # Returned here after a failed reinstall: surface the error and stop.
        if self.wheel_install_error:
            _LOGGER.warning("Failed to install wheel: %s", repr(self.wheel_install_error))
            return self.async_abort(reason="pip_wheel_error")

        # Returned here after a successful reinstall: record the now-installed
        # version (queried, not assumed) and save the options.
        if self.wheel_install_successful:
            client_config[CONF_INSTALLED_LLAMACPP_VERSION] = await self.hass.async_add_executor_job(get_llama_cpp_python_version)
            _LOGGER.debug(f"new version is: {client_config[CONF_INSTALLED_LLAMACPP_VERSION]}")
            return self.async_create_entry(data=client_config)

        if backend_type == BACKEND_TYPE_LLAMA_CPP:
            potential_versions = await get_available_llama_cpp_versions(self.hass)

            schema = vol.Schema({
                vol.Required(CONF_LLAMACPP_REINSTALL, default=False): BooleanSelector(BooleanSelectorConfig()),
                # (version, is_local) tuples; locally-cached wheels are labeled.
                vol.Required(CONF_INSTALLED_LLAMACPP_VERSION, default=client_config.get(CONF_INSTALLED_LLAMACPP_VERSION, "not installed")): SelectSelector(
                    SelectSelectorConfig(
                        options=[ SelectOptionDict(value=x[0], label=x[0] if not x[1] else f"{x[0]} (local)") for x in potential_versions ],
                        mode=SelectSelectorMode.DROPDOWN,
                    )
                )
            })

            # step_id "reinstall" routes the submission to async_step_reinstall.
            return self.async_show_form(
                step_id="reinstall",
                data_schema=schema,
            )
        else:

            if user_input is not None:
                client_config.update(user_input)

                # validate remote connections
                connect_err = await BACKEND_TO_CLS[backend_type].async_validate_connection(self.hass, client_config)

                if not connect_err:
                    return self.async_create_entry(data=client_config)
                else:
                    errors["base"] = "failed_to_connect"
                    description_placeholders["exception"] = str(connect_err)

            # First display, or validation failed: show the connection form
            # pre-filled from the stored options.
            schema = remote_connection_schema(
                backend_type=backend_type,
                host=client_config.get(CONF_HOST),
                port=client_config.get(CONF_PORT),
                ssl=client_config.get(CONF_SSL),
                selected_path=client_config.get(CONF_GENERIC_OPENAI_PATH)
            )

            return self.async_show_form(
                step_id="init",
                data_schema=schema,
                errors=errors,
                description_placeholders=description_placeholders,
            )

    async def async_step_reinstall(self, user_input: dict[str, Any] | None = None) -> ConfigFlowResult:
        # Handle the llama.cpp wheel reinstall form and drive the install task.
        client_config = dict(self.config_entry.options)

        if user_input is not None:
            if not user_input[CONF_LLAMACPP_REINSTALL]:
                # Checkbox unticked: keep options unchanged and finish.
                _LOGGER.debug("Reinstall was not selected, finishing")
                return self.async_create_entry(data=client_config)

        if not self.reinstall_task:
            if not user_input:
                # Progress callback with no task and no input: nothing to do.
                return self.async_abort(reason="unknown")

            desired_version = user_input.get(CONF_INSTALLED_LLAMACPP_VERSION)
            async def install_task():
                # pip install is blocking; run in the executor.
                # True here presumably forces the reinstall — confirm against
                # install_llama_cpp_python's signature.
                return await self.hass.async_add_executor_job(
                    install_llama_cpp_python, self.hass.config.config_dir, True, desired_version
                )

            self.reinstall_task = self.hass.async_create_background_task(
                install_task(), name="llama_cpp_python_installation")

            _LOGGER.debug("Queuing reinstall task...")
            return self.async_show_progress(
                progress_task=self.reinstall_task,
                step_id="reinstall",
                progress_action="install_local_wheels",
            )

        if not self.reinstall_task.done():
            # Still running: keep showing the progress screen.
            return self.async_show_progress(
                progress_task=self.reinstall_task,
                step_id="reinstall",
                progress_action="install_local_wheels",
            )

        # Task done: record the outcome in instance state and hand control
        # back to async_step_init, which reads wheel_install_error /
        # wheel_install_successful.
        _LOGGER.debug("done... checking result")
        install_exception = self.reinstall_task.exception()
        if install_exception:
            self.wheel_install_error = repr(install_exception)
            _LOGGER.debug(f"Hit error: {self.wheel_install_error}")
            return self.async_show_progress_done(next_step_id="init")
        else:
            wheel_install_result = self.reinstall_task.result()
            if not wheel_install_result:
                # pip reported failure without raising.
                self.wheel_install_error = "Pip returned false"
                _LOGGER.debug(f"Hit error: {self.wheel_install_error} ({wheel_install_result})")
                return self.async_show_progress_done(next_step_id="init")
            else:
                _LOGGER.debug(f"Finished install: {wheel_install_result}")
                self.wheel_install_successful = True
                return self.async_show_progress_done(next_step_id="init")
|
|
|
|
|
|
def STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(model_file=None, chat_model=None, downloaded_model_quantization=None, available_quantizations=None):
    """Schema for selecting a model for the embedded llama.cpp backend.

    The user either picks a HuggingFace repo (with a quantization to
    download) or provides a path to an already-downloaded model file.
    All arguments pre-fill the form when re-shown.
    """
    quant_options = available_quantizations or CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS
    model_selector = SelectSelector(SelectSelectorConfig(
        options=RECOMMENDED_CHAT_MODELS,
        custom_value=True,
        multiple=False,
        mode=SelectSelectorMode.DROPDOWN,
    ))

    return vol.Schema({
        vol.Optional(CONF_CHAT_MODEL, default=chat_model or DEFAULT_CHAT_MODEL): model_selector,
        vol.Optional(
            CONF_DOWNLOADED_MODEL_QUANTIZATION,
            default=downloaded_model_quantization or DEFAULT_DOWNLOADED_MODEL_QUANTIZATION,
        ): vol.In(quant_options),
        vol.Optional(CONF_DOWNLOADED_MODEL_FILE, default=model_file or ""): str,
    })
|
|
|
|
def STEP_REMOTE_MODEL_SELECTION_DATA_SCHEMA(available_models: list[str], chat_model: str | None = None):
    """Schema for selecting which model a remote backend should use.

    :param available_models: model names advertised by the remote backend
    :param chat_model: previously-selected model to pre-fill, if any

    Fix: the original unconditionally indexed ``available_models[0]``, which
    raised ``IndexError`` when the backend advertised no models and no prior
    selection existed.  An empty default is safe because the selector accepts
    custom values typed by the user.
    """
    _LOGGER.debug("available models: %s", available_models)

    if chat_model:
        default_model = chat_model
    elif available_models:
        default_model = available_models[0]
    else:
        default_model = ""

    return vol.Schema(
        {
            vol.Required(CONF_CHAT_MODEL, default=default_model): SelectSelector(SelectSelectorConfig(
                options=available_models,
                custom_value=True,
                multiple=False,
                mode=SelectSelectorMode.DROPDOWN,
            )),
        }
    )
|
|
|
|
def build_prompt_template(selected_language: str, prompt_template_template: str):
    """Expand the placeholder tags in a prompt template for one language.

    Each ``<tag>`` in *prompt_template_template* is substituted with the
    matching localized snippet; languages without a translation fall back
    to English.
    """
    # Placeholder tag -> table of per-language snippets.
    snippet_tables = {
        "<persona>": PERSONA_PROMPTS,
        "<current_date>": CURRENT_DATE_PROMPT,
        "<devices>": DEVICES_PROMPT,
        "<services>": SERVICES_PROMPT,
        "<tools>": TOOLS_PROMPT,
        "<area>": AREA_PROMPT,
        "<user_instruction>": USER_INSTRUCTION,
    }

    rendered = prompt_template_template
    for tag, table in snippet_tables.items():
        rendered = rendered.replace(tag, table.get(selected_language, table["en"]))

    return rendered
|
|
|
|
def local_llama_config_option_schema(
|
|
hass: HomeAssistant,
|
|
language: str,
|
|
options: dict[str, Any],
|
|
backend_type: str,
|
|
subentry_type: str,
|
|
) -> dict:
|
|
|
|
result: dict = {
|
|
vol.Optional(
|
|
CONF_TEMPERATURE,
|
|
description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)},
|
|
default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
|
|
): NumberSelector(NumberSelectorConfig(min=0.0, max=2.0, step=0.05, mode=NumberSelectorMode.BOX)),
|
|
vol.Required(
|
|
CONF_THINKING_PREFIX,
|
|
description={"suggested_value": options.get(CONF_THINKING_PREFIX)},
|
|
default=DEFAULT_THINKING_PREFIX,
|
|
): str,
|
|
vol.Required(
|
|
CONF_THINKING_SUFFIX,
|
|
description={"suggested_value": options.get(CONF_THINKING_SUFFIX)},
|
|
default=DEFAULT_THINKING_SUFFIX,
|
|
): str,
|
|
vol.Required(
|
|
CONF_TOOL_CALL_PREFIX,
|
|
description={"suggested_value": options.get(CONF_TOOL_CALL_PREFIX)},
|
|
default=DEFAULT_TOOL_CALL_PREFIX,
|
|
): str,
|
|
vol.Required(
|
|
CONF_TOOL_CALL_SUFFIX,
|
|
description={"suggested_value": options.get(CONF_TOOL_CALL_SUFFIX)},
|
|
default=DEFAULT_TOOL_CALL_SUFFIX,
|
|
): str,
|
|
vol.Required(
|
|
CONF_ENABLE_LEGACY_TOOL_CALLING,
|
|
description={"suggested_value": options.get(CONF_ENABLE_LEGACY_TOOL_CALLING)},
|
|
default=DEFAULT_ENABLE_LEGACY_TOOL_CALLING
|
|
): bool,
|
|
}
|
|
|
|
if subentry_type == ai_task.DOMAIN:
|
|
result.update({
|
|
vol.Optional(
|
|
CONF_PROMPT,
|
|
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT)},
|
|
default=options.get(CONF_PROMPT, DEFAULT_AI_TASK_PROMPT),
|
|
): TemplateSelector(),
|
|
vol.Required(
|
|
CONF_AI_TASK_EXTRACTION_METHOD,
|
|
description={"suggested_value": options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD)},
|
|
default=options.get(CONF_AI_TASK_EXTRACTION_METHOD, DEFAULT_AI_TASK_EXTRACTION_METHOD),
|
|
): SelectSelector(SelectSelectorConfig(
|
|
options=[
|
|
SelectOptionDict(value="none", label="None"),
|
|
SelectOptionDict(value="structure", label="Structured output"),
|
|
SelectOptionDict(value="tool", label="Tool call"),
|
|
],
|
|
mode=SelectSelectorMode.DROPDOWN,
|
|
)),
|
|
vol.Required(
|
|
CONF_AI_TASK_RETRIES,
|
|
description={"suggested_value": options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES)},
|
|
default=options.get(CONF_AI_TASK_RETRIES, DEFAULT_AI_TASK_RETRIES),
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=5, step=1, mode=NumberSelectorMode.BOX)),
|
|
})
|
|
elif subentry_type == conversation.DOMAIN:
|
|
default_prompt = build_prompt_template(language, DEFAULT_PROMPT)
|
|
apis: list[SelectOptionDict] = [
|
|
SelectOptionDict(
|
|
label=api.name,
|
|
value=api.id,
|
|
)
|
|
for api in llm.async_get_apis(hass)
|
|
]
|
|
result.update({
|
|
vol.Optional(
|
|
CONF_PROMPT,
|
|
description={"suggested_value": options.get(CONF_PROMPT, default_prompt)},
|
|
default=options.get(CONF_PROMPT, default_prompt),
|
|
): TemplateSelector(),
|
|
vol.Required(
|
|
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
|
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
|
|
default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
|
): BooleanSelector(BooleanSelectorConfig()),
|
|
vol.Required(
|
|
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
|
description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
|
|
default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
|
): str,
|
|
vol.Required(
|
|
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
|
description={"suggested_value": options.get(CONF_NUM_IN_CONTEXT_EXAMPLES)},
|
|
default=DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=16, step=1)),
|
|
vol.Required(
|
|
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
|
description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
|
|
default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
|
): TextSelector(TextSelectorConfig(multiple=True)),
|
|
vol.Optional(
|
|
CONF_LLM_HASS_API,
|
|
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
|
default=None,
|
|
): SelectSelector(SelectSelectorConfig(options=apis, multiple=True)),
|
|
vol.Optional(
|
|
CONF_REFRESH_SYSTEM_PROMPT,
|
|
description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)},
|
|
default=options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT),
|
|
): BooleanSelector(BooleanSelectorConfig()),
|
|
vol.Optional(
|
|
CONF_REMEMBER_CONVERSATION,
|
|
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)},
|
|
default=options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION),
|
|
): BooleanSelector(BooleanSelectorConfig()),
|
|
vol.Optional(
|
|
CONF_REMEMBER_NUM_INTERACTIONS,
|
|
description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)},
|
|
default=options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS),
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=100, mode=NumberSelectorMode.BOX)),
|
|
vol.Optional(
|
|
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
|
|
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION)},
|
|
default=options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION),
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1440, mode=NumberSelectorMode.BOX)),
|
|
vol.Required(
|
|
CONF_MAX_TOOL_CALL_ITERATIONS,
|
|
description={"suggested_value": options.get(CONF_MAX_TOOL_CALL_ITERATIONS)},
|
|
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
|
|
): int,
|
|
})
|
|
|
|
if backend_type == BACKEND_TYPE_LLAMA_CPP:
|
|
if subentry_type == conversation.DOMAIN:
|
|
result.update({
|
|
vol.Required(
|
|
CONF_PROMPT_CACHING_ENABLED,
|
|
description={"suggested_value": options.get(CONF_PROMPT_CACHING_ENABLED)},
|
|
default=DEFAULT_PROMPT_CACHING_ENABLED,
|
|
): BooleanSelector(BooleanSelectorConfig()),
|
|
vol.Required(
|
|
CONF_PROMPT_CACHING_INTERVAL,
|
|
description={"suggested_value": options.get(CONF_PROMPT_CACHING_INTERVAL)},
|
|
default=DEFAULT_PROMPT_CACHING_INTERVAL,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=60, step=1)),
|
|
})
|
|
result.update({
|
|
vol.Required(
|
|
CONF_MAX_TOKENS,
|
|
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
|
default=DEFAULT_MAX_TOKENS,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
|
|
vol.Required(
|
|
CONF_TOP_K,
|
|
description={"suggested_value": options.get(CONF_TOP_K)},
|
|
default=DEFAULT_TOP_K,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=256, step=1)),
|
|
vol.Required(
|
|
CONF_TOP_P,
|
|
description={"suggested_value": options.get(CONF_TOP_P)},
|
|
default=DEFAULT_TOP_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
|
vol.Required(
|
|
CONF_MIN_P,
|
|
description={"suggested_value": options.get(CONF_MIN_P)},
|
|
default=DEFAULT_MIN_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
|
vol.Required(
|
|
CONF_TYPICAL_P,
|
|
description={"suggested_value": options.get(CONF_TYPICAL_P)},
|
|
default=DEFAULT_TYPICAL_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
|
# TODO: add rope_scaling_type
|
|
vol.Required(
|
|
CONF_CONTEXT_LENGTH,
|
|
description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
|
|
default=DEFAULT_CONTEXT_LENGTH,
|
|
): NumberSelector(NumberSelectorConfig(min=512, max=1_048_576, step=512)),
|
|
vol.Required(
|
|
CONF_LLAMACPP_BATCH_SIZE,
|
|
description={"suggested_value": options.get(CONF_LLAMACPP_BATCH_SIZE)},
|
|
default=DEFAULT_LLAMACPP_BATCH_SIZE,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
|
|
vol.Required(
|
|
CONF_LLAMACPP_THREAD_COUNT,
|
|
description={"suggested_value": options.get(CONF_LLAMACPP_THREAD_COUNT)},
|
|
default=DEFAULT_LLAMACPP_THREAD_COUNT,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=((os.cpu_count() or 1) * 2), step=1)),
|
|
vol.Required(
|
|
CONF_LLAMACPP_BATCH_THREAD_COUNT,
|
|
description={"suggested_value": options.get(CONF_LLAMACPP_BATCH_THREAD_COUNT)},
|
|
default=DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=((os.cpu_count() or 1) * 2), step=1)),
|
|
vol.Required(
|
|
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
|
description={"suggested_value": options.get(CONF_LLAMACPP_ENABLE_FLASH_ATTENTION)},
|
|
default=DEFAULT_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
|
): BooleanSelector(BooleanSelectorConfig()),
|
|
vol.Required(
|
|
CONF_USE_GBNF_GRAMMAR,
|
|
description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
|
|
default=DEFAULT_USE_GBNF_GRAMMAR,
|
|
): BooleanSelector(BooleanSelectorConfig()),
|
|
vol.Required(
|
|
CONF_GBNF_GRAMMAR_FILE,
|
|
description={"suggested_value": options.get(CONF_GBNF_GRAMMAR_FILE)},
|
|
default=DEFAULT_GBNF_GRAMMAR_FILE,
|
|
): str
|
|
})
|
|
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
|
|
result.update({
|
|
vol.Required(
|
|
CONF_CONTEXT_LENGTH,
|
|
description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
|
|
default=DEFAULT_CONTEXT_LENGTH,
|
|
): NumberSelector(NumberSelectorConfig(min=512, max=1_048_576, step=512)),
|
|
vol.Required(
|
|
CONF_TOP_K,
|
|
description={"suggested_value": options.get(CONF_TOP_K)},
|
|
default=DEFAULT_TOP_K,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=256, step=1)),
|
|
vol.Required(
|
|
CONF_TOP_P,
|
|
description={"suggested_value": options.get(CONF_TOP_P)},
|
|
default=DEFAULT_TOP_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
|
vol.Required(
|
|
CONF_MIN_P,
|
|
description={"suggested_value": options.get(CONF_MIN_P)},
|
|
default=DEFAULT_MIN_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
|
vol.Required(
|
|
CONF_TYPICAL_P,
|
|
description={"suggested_value": options.get(CONF_TYPICAL_P)},
|
|
default=DEFAULT_TYPICAL_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
|
vol.Required(
|
|
CONF_REQUEST_TIMEOUT,
|
|
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
|
|
default=DEFAULT_REQUEST_TIMEOUT,
|
|
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
|
|
vol.Optional(
|
|
CONF_TEXT_GEN_WEBUI_PRESET,
|
|
description={"suggested_value": options.get(CONF_TEXT_GEN_WEBUI_PRESET)},
|
|
): str,
|
|
vol.Required(
|
|
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
|
description={"suggested_value": options.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE)},
|
|
default=DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
|
|
): SelectSelector(SelectSelectorConfig(
|
|
options=[TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
|
|
translation_key=CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
|
multiple=False,
|
|
mode=SelectSelectorMode.DROPDOWN,
|
|
)),
|
|
})
|
|
elif backend_type in BACKEND_TYPE_GENERIC_OPENAI:
|
|
result.update({
|
|
vol.Required(
|
|
CONF_TOP_P,
|
|
description={"suggested_value": options.get(CONF_TOP_P)},
|
|
default=DEFAULT_TOP_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
|
vol.Required(
|
|
CONF_REQUEST_TIMEOUT,
|
|
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
|
|
default=DEFAULT_REQUEST_TIMEOUT,
|
|
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
|
|
})
|
|
elif backend_type in BACKEND_TYPE_GENERIC_OPENAI_RESPONSES:
|
|
del result[CONF_REMEMBER_NUM_INTERACTIONS]
|
|
result.update({
|
|
vol.Required(
|
|
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
|
|
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES)},
|
|
default=DEFAULT_TOP_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=180, step=0.5, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)),
|
|
vol.Required(
|
|
CONF_TEMPERATURE,
|
|
description={"suggested_value": options.get(CONF_TEMPERATURE)},
|
|
default=DEFAULT_TEMPERATURE,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=3, step=0.05)),
|
|
vol.Required(
|
|
CONF_TOP_P,
|
|
description={"suggested_value": options.get(CONF_TOP_P)},
|
|
default=DEFAULT_TOP_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
|
vol.Required(
|
|
CONF_REQUEST_TIMEOUT,
|
|
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
|
|
default=DEFAULT_REQUEST_TIMEOUT,
|
|
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
|
|
})
|
|
elif backend_type == BACKEND_TYPE_LLAMA_CPP_SERVER:
|
|
result.update({
|
|
vol.Required(
|
|
CONF_MAX_TOKENS,
|
|
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
|
default=DEFAULT_MAX_TOKENS,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
|
|
vol.Required(
|
|
CONF_TOP_K,
|
|
description={"suggested_value": options.get(CONF_TOP_K)},
|
|
default=DEFAULT_TOP_K,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=256, step=1)),
|
|
vol.Required(
|
|
CONF_TOP_P,
|
|
description={"suggested_value": options.get(CONF_TOP_P)},
|
|
default=DEFAULT_TOP_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
|
vol.Required(
|
|
CONF_USE_GBNF_GRAMMAR,
|
|
description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
|
|
default=DEFAULT_USE_GBNF_GRAMMAR,
|
|
): BooleanSelector(BooleanSelectorConfig()),
|
|
vol.Required(
|
|
CONF_GBNF_GRAMMAR_FILE,
|
|
description={"suggested_value": options.get(CONF_GBNF_GRAMMAR_FILE)},
|
|
default=DEFAULT_GBNF_GRAMMAR_FILE,
|
|
): str,
|
|
vol.Required(
|
|
CONF_REQUEST_TIMEOUT,
|
|
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
|
|
default=DEFAULT_REQUEST_TIMEOUT,
|
|
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
|
|
})
|
|
elif backend_type == BACKEND_TYPE_OLLAMA:
|
|
result.update({
|
|
vol.Required(
|
|
CONF_MAX_TOKENS,
|
|
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
|
default=DEFAULT_MAX_TOKENS,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
|
|
vol.Required(
|
|
CONF_CONTEXT_LENGTH,
|
|
description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
|
|
default=DEFAULT_CONTEXT_LENGTH,
|
|
): NumberSelector(NumberSelectorConfig(min=512, max=1_048_576, step=512)),
|
|
vol.Required(
|
|
CONF_TOP_K,
|
|
description={"suggested_value": options.get(CONF_TOP_K)},
|
|
default=DEFAULT_TOP_K,
|
|
): NumberSelector(NumberSelectorConfig(min=1, max=256, step=1)),
|
|
vol.Required(
|
|
CONF_TOP_P,
|
|
description={"suggested_value": options.get(CONF_TOP_P)},
|
|
default=DEFAULT_TOP_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
|
vol.Required(
|
|
CONF_TYPICAL_P,
|
|
description={"suggested_value": options.get(CONF_TYPICAL_P)},
|
|
default=DEFAULT_TYPICAL_P,
|
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
|
vol.Required(
|
|
CONF_OLLAMA_JSON_MODE,
|
|
description={"suggested_value": options.get(CONF_OLLAMA_JSON_MODE)},
|
|
default=DEFAULT_OLLAMA_JSON_MODE,
|
|
): BooleanSelector(BooleanSelectorConfig()),
|
|
vol.Required(
|
|
CONF_REQUEST_TIMEOUT,
|
|
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
|
|
default=DEFAULT_REQUEST_TIMEOUT,
|
|
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
|
|
vol.Required(
|
|
CONF_OLLAMA_KEEP_ALIVE_MIN,
|
|
description={"suggested_value": options.get(CONF_OLLAMA_KEEP_ALIVE_MIN)},
|
|
default=DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
|
|
): NumberSelector(NumberSelectorConfig(min=-1, max=1440, step=1, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)),
|
|
})
|
|
|
|
# sort the options
|
|
global_order = [
|
|
# general
|
|
CONF_LLM_HASS_API,
|
|
CONF_PROMPT,
|
|
CONF_AI_TASK_EXTRACTION_METHOD,
|
|
CONF_AI_TASK_RETRIES,
|
|
CONF_CONTEXT_LENGTH,
|
|
CONF_MAX_TOKENS,
|
|
# sampling parameters
|
|
CONF_TEMPERATURE,
|
|
CONF_TOP_P,
|
|
CONF_MIN_P,
|
|
CONF_TYPICAL_P,
|
|
CONF_TOP_K,
|
|
# tool calling/reasoning
|
|
CONF_THINKING_PREFIX,
|
|
CONF_THINKING_SUFFIX,
|
|
CONF_TOOL_CALL_PREFIX,
|
|
CONF_TOOL_CALL_SUFFIX,
|
|
CONF_MAX_TOOL_CALL_ITERATIONS,
|
|
CONF_ENABLE_LEGACY_TOOL_CALLING,
|
|
CONF_USE_GBNF_GRAMMAR,
|
|
CONF_GBNF_GRAMMAR_FILE,
|
|
# integration specific options
|
|
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
|
CONF_REFRESH_SYSTEM_PROMPT,
|
|
CONF_REMEMBER_CONVERSATION,
|
|
CONF_REMEMBER_NUM_INTERACTIONS,
|
|
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
|
|
CONF_PROMPT_CACHING_ENABLED,
|
|
CONF_PROMPT_CACHING_INTERVAL,
|
|
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
|
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
|
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
|
# backend specific options
|
|
CONF_LLAMACPP_BATCH_SIZE,
|
|
CONF_LLAMACPP_THREAD_COUNT,
|
|
CONF_LLAMACPP_BATCH_THREAD_COUNT,
|
|
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION,
|
|
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
|
|
CONF_TEXT_GEN_WEBUI_PRESET,
|
|
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
|
CONF_OLLAMA_KEEP_ALIVE_MIN,
|
|
CONF_OLLAMA_JSON_MODE,
|
|
]
|
|
|
|
result = { k: v for k, v in sorted(result.items(), key=lambda item: global_order.index(item[0]) if item[0] in global_order else 9999) }
|
|
|
|
return result
|
|
|
|
|
|
class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
|
"""Flow for managing Local LLM subentries."""
|
|
|
|
def __init__(self) -> None:
    """Set up the per-flow state used by the subentry wizard."""
    super().__init__()

    # Options accumulated across the wizard steps; persisted on finish.
    self.model_config: dict[str, Any] = {}
    # In-flight HuggingFace download task, if any.
    self.download_task: Task | None = None
    # Exception raised by the most recent failed download attempt.
    self.download_error: Exception | None = None
|
|
|
|
@property
def _is_new(self) -> bool:
    """True when the flow was launched to create a brand-new subentry."""
    # Flows started from the "add" button carry source "user";
    # reconfigure flows use a different source.
    started_by_user = self.source == "user"
    return started_by_user
|
|
|
|
@property
def _client(self) -> LocalLLMClient:
    """Return the backend client stored on the parent config entry."""
    # runtime_data is populated with a LocalLLMClient when the parent
    # entry is set up; async_step_user guards that the entry is loaded.
    parent: LocalLLMConfigEntry = self._get_entry()
    return parent.runtime_data
|
|
|
|
async def async_step_pick_model(
    self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
    """Let the user pick a model, surfacing any error from a failed download.

    For the llama.cpp backend the user may either name a HuggingFace model
    (which triggers the download step) or point at an existing GGUF file on
    disk; remote backends choose from the models the server reports.
    """
    schema = vol.Schema({})
    errors = {}
    description_placeholders = {}
    entry = self._get_entry()

    backend_type = entry.data[CONF_BACKEND_TYPE]
    if backend_type == BACKEND_TYPE_LLAMA_CPP:
        schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA()
    else:
        schema = STEP_REMOTE_MODEL_SELECTION_DATA_SCHEMA(await entry.runtime_data.async_get_available_models())

    # Re-entered after a failed download: rebuild the form with the error.
    if self.download_error:
        if isinstance(self.download_error, MissingQuantizationException):
            # Only offer quantizations that both exist on the hub and are
            # supported by this integration.
            available_quants = list(set(self.download_error.available_quants).intersection(set(CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS)))

            if len(available_quants) == 0:
                errors["base"] = "no_supported_ggufs"
                schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
                    chat_model=self.model_config[CONF_CHAT_MODEL],
                    downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION],
                )
            else:
                errors["base"] = "missing_quantization"
                description_placeholders["missing"] = self.download_error.missing_quant
                description_placeholders["available"] = ", ".join(self.download_error.available_quants)

                # Pre-select the first quantization that actually exists.
                schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
                    chat_model=self.model_config[CONF_CHAT_MODEL],
                    downloaded_model_quantization=self.download_error.available_quants[0],
                    available_quantizations=available_quants,
                )
        else:
            errors["base"] = "download_failed"
            description_placeholders["exception"] = str(self.download_error)
            schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
                chat_model=self.model_config[CONF_CHAT_MODEL],
                downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION],
            )

    if user_input and "result" not in user_input:

        self.model_config.update(user_input)

        if backend_type == BACKEND_TYPE_LLAMA_CPP:
            model_file = self.model_config.get(CONF_DOWNLOADED_MODEL_FILE, "")
            if not model_file:
                model_name = self.model_config.get(CONF_CHAT_MODEL)
                if model_name:
                    return await self.async_step_download()
                # FIX: previously this fell through to the
                # os.path.exists("") check below, which clobbered this
                # error with "missing_model_file" and rebuilt the schema
                # with an empty path.
                errors["base"] = "no_model_name_or_file"
            elif os.path.exists(model_file):
                # Use the file's basename as the model name going forward.
                self.model_config[CONF_CHAT_MODEL] = os.path.basename(model_file)
                return await self.async_step_model_parameters()
            else:
                errors["base"] = "missing_model_file"
                # NOTE(review): positional argument here — confirm the schema
                # helper's first parameter is the intended one for a file path.
                schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(model_file)
        else:
            # Remote backends need no local file; go straight to parameters.
            return await self.async_step_model_parameters()

    return self.async_show_form(
        step_id="pick_model",
        data_schema=schema,
        errors=errors,
        description_placeholders=description_placeholders,
        last_step=False,
    )
|
|
|
|
async def async_step_download(
    self, user_input: dict[str, Any] | None = None,
) -> SubentryFlowResult:
    """Run the model download in the background and report progress."""
    if not self.download_task:
        model_name = self.model_config[CONF_CHAT_MODEL]
        quantization_type = self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION]
        # Store models under the "local" media dir (fall back to media path).
        storage_folder = os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "models")

        async def _run_download():
            # download_model_from_hf is blocking; run it in the executor.
            return await self.hass.async_add_executor_job(
                download_model_from_hf, model_name, quantization_type, storage_folder
            )

        self.download_task = self.hass.async_create_background_task(
            _run_download(), name="model_download_task")

    # A freshly-created task cannot be done yet (no event-loop turn has
    # happened), so this single check covers both the "just started" and
    # the "still running" cases from re-entry.
    if not self.download_task.done():
        return self.async_show_progress(
            progress_task=self.download_task,
            step_id="download",
            progress_action="download",
        )

    finished = self.download_task
    self.download_task = None

    download_exception = finished.exception()
    if download_exception:
        _LOGGER.info("Failed to download model: %s", repr(download_exception))
        # Stash the error so pick_model can rebuild its form around it.
        self.download_error = download_exception
        return self.async_show_progress_done(next_step_id="pick_model")

    self.model_config[CONF_DOWNLOADED_MODEL_FILE] = finished.result()
    return self.async_show_progress_done(next_step_id="model_parameters")
|
|
|
|
async def async_step_model_parameters(
    self, user_input: dict[str, Any] | None = None,
) -> SubentryFlowResult:
    """Collect and validate the per-model options for this subentry.

    Seeds ``self.model_config`` with sensible defaults (AI-task defaults, or
    model-name-matched overrides plus a language-specific prompt template),
    then shows/validates the backend-specific options form.
    """
    errors = {}
    description_placeholders = {}
    entry = self._get_entry()
    backend_type = entry.data[CONF_BACKEND_TYPE]
    # AI-task subentries get a different default prompt and extra options.
    is_ai_task = self._subentry_type == ai_task.DOMAIN

    if is_ai_task:
        # Fill AI-task defaults only for keys the user hasn't set yet.
        if CONF_PROMPT not in self.model_config:
            self.model_config[CONF_PROMPT] = DEFAULT_AI_TASK_PROMPT
        if CONF_AI_TASK_RETRIES not in self.model_config:
            self.model_config[CONF_AI_TASK_RETRIES] = DEFAULT_AI_TASK_RETRIES
        if CONF_AI_TASK_EXTRACTION_METHOD not in self.model_config:
            self.model_config[CONF_AI_TASK_EXTRACTION_METHOD] = DEFAULT_AI_TASK_EXTRACTION_METHOD
    elif CONF_PROMPT not in self.model_config:
        # determine selected language from model config or parent options
        selected_language = self.model_config.get(
            CONF_SELECTED_LANGUAGE, entry.options.get(CONF_SELECTED_LANGUAGE, "en")
        )
        model_name = self.model_config.get(CONF_CHAT_MODEL, "").lower()

        OPTIONS_OVERRIDES = option_overrides(backend_type)

        selected_default_options = {**DEFAULT_OPTIONS}
        # First override whose key appears in the model name wins.
        for key in OPTIONS_OVERRIDES.keys():
            if key in model_name:
                selected_default_options.update(OPTIONS_OVERRIDES[key])
                break

        # Build prompt template using the selected language
        selected_default_options[CONF_PROMPT] = build_prompt_template(
            selected_language, str(selected_default_options.get(CONF_PROMPT, DEFAULT_PROMPT))
        )

        # User-provided values take precedence over the computed defaults.
        self.model_config = {**selected_default_options, **self.model_config}

    schema = vol.Schema(
        local_llama_config_option_schema(
            self.hass,
            entry.options.get(CONF_SELECTED_LANGUAGE, "en"),
            self.model_config,
            backend_type,
            self._subentry_type,
        )
    )

    if user_input:
        if not is_ai_task:
            # Prompt caching requires the system prompt to be refreshed.
            if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
                errors["base"] = "sys_refresh_caching_enabled"

        # Referenced grammar/ICL files must exist next to this module.
        if user_input.get(CONF_USE_GBNF_GRAMMAR):
            filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
            if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
                errors["base"] = "missing_gbnf_file"
                description_placeholders["filename"] = filename

        if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
            filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
            if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
                errors["base"] = "missing_icl_file"
                description_placeholders["filename"] = filename

        # --- Normalize numeric fields to ints to avoid slice/type errors later ---
        for key in (
            CONF_REMEMBER_NUM_INTERACTIONS,
            CONF_MAX_TOOL_CALL_ITERATIONS,
            CONF_CONTEXT_LENGTH,
            CONF_MAX_TOKENS,
            CONF_REQUEST_TIMEOUT,
            CONF_AI_TASK_RETRIES,
        ):
            if key in user_input:
                # NOTE(review): the fallback is the same value being coerced
                # (or 0 when falsy) — confirm this matches _coerce_int's
                # intended (value, default) contract.
                user_input[key] = _coerce_int(user_input[key], user_input.get(key) or 0)

        if len(errors) == 0:
            try:
                # validate input
                schema(user_input)
                self.model_config.update(user_input)

                return await self.async_step_finish()
            except Exception:
                # Broad catch: any schema/validation failure is shown to the
                # user as a generic error rather than aborting the flow.
                _LOGGER.exception("An unknown error has occurred!")
                errors["base"] = "unknown"

    return self.async_show_form(
        step_id="model_parameters",
        data_schema=schema,
        errors=errors,
        description_placeholders=description_placeholders,
    )
|
|
|
|
async def async_step_finish(
    self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
    """Persist the collected model configuration and end the flow."""
    # Reconfigure flows overwrite the existing subentry; new flows create one.
    if not self._is_new:
        return self.async_update_and_abort(
            self._get_entry(),
            self._get_reconfigure_subentry(),
            data=self.model_config,
        )

    title = self.model_config.get(CONF_CHAT_MODEL, "Model")
    return self.async_create_entry(title=title, data=self.model_config)
|
|
|
|
async def async_step_user(
    self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
    """Start a new subentry flow: model selection then configuration."""

    # The later steps need the parent entry's runtime client, so refuse
    # to start unless the parent entry is fully loaded.
    if self._get_entry().state != ConfigEntryState.LOADED:
        return self.async_abort(reason="entry_not_loaded")

    # Ensure the wizard state dict exists.
    self.model_config = self.model_config or {}

    return await self.async_step_pick_model(user_input)

# "init" flows use the same entry point as "user" flows.
async_step_init = async_step_user
|
|
|
|
async def async_step_reconfigure(
    self, user_input: dict[str, Any] | None = None):
    """Reconfigure an existing subentry, seeded with its current options."""
    if not self.model_config:
        # Copy the stored options so edits don't mutate the live subentry.
        current = self._get_reconfigure_subentry()
        self.model_config = dict(current.data)

    return await self.async_step_model_parameters(user_input)
|