mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-09 13:48:05 -05:00
more code cleanup and config_flow fixes
This commit is contained in:
11
TODO.md
11
TODO.md
@@ -50,9 +50,16 @@
|
||||
- [x] split out entity functionality so we can support conversation + ai tasks
|
||||
- [x] fix icl examples to match new tool calling syntax config
|
||||
- [x] set up docker-compose for running all of the various backends
|
||||
- [ ] fix and re-upload all compatible old models (+ upload all original safetensors)
|
||||
- [x] fix the openai responses backend
|
||||
- [ ] config entry migration function?
|
||||
- [ ] config sub-entry implementation
|
||||
- [x] base work
|
||||
- [x] generic openai backend
|
||||
- [x] llamacpp backend
|
||||
- [x] ollama backend
|
||||
- [ ] tailored_openai backend
|
||||
- [ ] generic openai responses backend
|
||||
- [ ] fix and re-upload all compatible old models (+ upload all original safetensors)
|
||||
- [ ] config entry migration function
|
||||
|
||||
## more complicated ideas
|
||||
- [ ] "context requests"
|
||||
|
||||
@@ -21,19 +21,13 @@ from .const import (
|
||||
SERVICE_TOOL_ALLOWED_SERVICES,
|
||||
SERVICE_TOOL_ALLOWED_DOMAINS,
|
||||
CONF_BACKEND_TYPE,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
||||
DEFAULT_BACKEND_TYPE,
|
||||
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
BACKEND_TYPE_LLAMA_CPP,
|
||||
BACKEND_TYPE_TEXT_GEN_WEBUI,
|
||||
BACKEND_TYPE_GENERIC_OPENAI,
|
||||
BACKEND_TYPE_GENERIC_OPENAI_RESPONSES,
|
||||
BACKEND_TYPE_LLAMA_CPP_SERVER,
|
||||
BACKEND_TYPE_OLLAMA,
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_DOWNLOADED_MODEL_FILE,
|
||||
)
|
||||
from .entity import LocalLLMClient, LocalLLMConfigEntry
|
||||
from .backends.llamacpp import LlamaCppClient
|
||||
@@ -47,6 +41,15 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
|
||||
PLATFORMS = (Platform.CONVERSATION,)
|
||||
|
||||
BACKEND_TO_CLS: dict[str, type[LocalLLMClient]] = {
|
||||
BACKEND_TYPE_LLAMA_CPP: LlamaCppClient,
|
||||
BACKEND_TYPE_GENERIC_OPENAI: GenericOpenAIAPIClient,
|
||||
BACKEND_TYPE_GENERIC_OPENAI_RESPONSES: GenericOpenAIResponsesAPIClient,
|
||||
BACKEND_TYPE_TEXT_GEN_WEBUI: TextGenerationWebuiClient,
|
||||
BACKEND_TYPE_LLAMA_CPP_SERVER: LlamaCppServerClient,
|
||||
BACKEND_TYPE_OLLAMA: OllamaAPIClient,
|
||||
}
|
||||
|
||||
async def update_listener(hass: HomeAssistant, entry: LocalLLMConfigEntry):
|
||||
"""Handle options update."""
|
||||
hass.data[DOMAIN][entry.entry_id] = entry
|
||||
@@ -64,26 +67,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) ->
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry
|
||||
|
||||
def create_client(backend_type):
|
||||
client_cls = None
|
||||
|
||||
_LOGGER.debug("Creating Local LLM client of type %s", backend_type)
|
||||
|
||||
if backend_type == BACKEND_TYPE_LLAMA_CPP:
|
||||
client_cls = LlamaCppClient
|
||||
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
|
||||
client_cls = GenericOpenAIAPIClient
|
||||
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI_RESPONSES:
|
||||
client_cls = GenericOpenAIResponsesAPIClient
|
||||
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
|
||||
client_cls = TextGenerationWebuiClient
|
||||
elif backend_type == BACKEND_TYPE_LLAMA_CPP_SERVER:
|
||||
client_cls = LlamaCppServerClient
|
||||
elif backend_type == BACKEND_TYPE_OLLAMA:
|
||||
client_cls = OllamaAPIClient
|
||||
|
||||
if client_cls is None:
|
||||
raise ValueError(f"Unknown backend type {backend_type}")
|
||||
return client_cls(hass, dict(entry.options))
|
||||
return BACKEND_TO_CLS[backend_type](hass, dict(entry.options))
|
||||
|
||||
# create the agent in an executor job because the constructor calls `open()`
|
||||
backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
|
||||
@@ -97,7 +82,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) ->
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
|
||||
"""Unload Ollama."""
|
||||
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
||||
|
||||
@@ -210,47 +210,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
|
||||
|
||||
return response_text, tool_calls
|
||||
|
||||
|
||||
async def _async_validate_generic_openai(self, user_input: dict) -> tuple:
|
||||
"""
|
||||
Validates a connection to an OpenAI compatible API server and that the model exists on the remote server
|
||||
|
||||
:param user_input: the input dictionary used to build the connection
|
||||
:return: a tuple of (error message name, exception detail); both can be None
|
||||
"""
|
||||
try:
|
||||
headers = {}
|
||||
api_key = user_input.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, user_input.get(CONF_OPENAI_API_KEY))
|
||||
api_base_path = user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
session = async_get_clientsession(self.hass)
|
||||
async with session.get(
|
||||
format_url(
|
||||
hostname=self.model_config[CONF_HOST],
|
||||
port=self.model_config[CONF_PORT],
|
||||
ssl=self.model_config[CONF_SSL],
|
||||
path=f"/{api_base_path}/models"
|
||||
),
|
||||
timeout=5, # quick timeout
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
models_result = await response.json()
|
||||
|
||||
models = [ model["id"] for model in models_result["data"] ]
|
||||
|
||||
for model in models:
|
||||
if model == self.model_config[CONF_CHAT_MODEL]:
|
||||
return None, None, []
|
||||
|
||||
return "missing_model_api", None, models
|
||||
|
||||
except Exception as ex:
|
||||
_LOGGER.info("Connection error was: %s", repr(ex))
|
||||
return "failed_to_connect", ex, []
|
||||
|
||||
# FIXME: this class is mostly broken
|
||||
class GenericOpenAIResponsesAPIClient(LocalLLMClient):
|
||||
|
||||
@@ -60,10 +60,12 @@ from custom_components.llama_conversation.entity import LocalLLMClient, TextGene
|
||||
|
||||
# make type checking work for llama-cpp-python without importing it directly at runtime
|
||||
from typing import TYPE_CHECKING
|
||||
from types import ModuleType
|
||||
if TYPE_CHECKING:
|
||||
from llama_cpp import Llama as LlamaType
|
||||
from llama_cpp import Llama as LlamaType, LlamaGrammar as LlamaGrammarType
|
||||
else:
|
||||
LlamaType = Any
|
||||
LlamaGrammarType = Any
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -81,7 +83,7 @@ def snapshot_settings(options: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
class LlamaCppClient(LocalLLMClient):
|
||||
llama_cpp_module: Any | None
|
||||
llama_cpp_module: ModuleType | None
|
||||
|
||||
models: dict[str, LlamaType]
|
||||
grammars: dict[str, Any]
|
||||
@@ -111,7 +113,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
self.model_lock = threading.Lock()
|
||||
|
||||
async def async_get_available_models(self) -> List[str]:
|
||||
return [] # TODO: copy from config_flow.py
|
||||
return [] # TODO: find available "huggingface_hub" models that have been downloaded
|
||||
|
||||
def _load_model(self, entity_options: dict[str, Any]) -> None:
|
||||
model_name = entity_options.get(CONF_CHAT_MODEL, "")
|
||||
@@ -141,7 +143,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
validate_llama_cpp_python_installation()
|
||||
self.llama_cpp_module = importlib.import_module("llama_cpp")
|
||||
|
||||
Llama = getattr(self.llama_cpp_module, "Llama")
|
||||
Llama: type[LlamaType] = getattr(self.llama_cpp_module, "Llama")
|
||||
|
||||
_LOGGER.debug(f"Loading model '{model_path}'...")
|
||||
model_settings = snapshot_settings(entity_options)
|
||||
@@ -178,7 +180,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
|
||||
|
||||
def _load_grammar(self, model_name: str, filename: str) -> Any:
|
||||
LlamaGrammar = getattr(self.llama_cpp_module, "LlamaGrammar")
|
||||
LlamaGrammar: type[LlamaGrammarType] = getattr(self.llama_cpp_module, "LlamaGrammar")
|
||||
_LOGGER.debug(f"Loading grammar {filename}...")
|
||||
try:
|
||||
with open(os.path.join(os.path.dirname(__file__), filename)) as f:
|
||||
@@ -217,7 +219,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
_LOGGER.debug(f"Reloading model '{model_path}'...")
|
||||
model_settings = snapshot_settings(entity_options)
|
||||
|
||||
Llama = getattr(self.llama_cpp_module, "Llama")
|
||||
Llama: type[LlamaType] = getattr(self.llama_cpp_module, "Llama")
|
||||
llm = Llama(
|
||||
model_path=model_path,
|
||||
n_ctx=int(model_settings[CONF_CONTEXT_LENGTH]),
|
||||
|
||||
@@ -79,7 +79,7 @@ class OllamaAPIClient(LocalLLMClient):
|
||||
ssl=user_input[CONF_SSL],
|
||||
path=f"/{api_base_path}/api/tags"
|
||||
),
|
||||
timeout=5, # quick timeout
|
||||
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
|
||||
headers=headers
|
||||
) as response:
|
||||
if response.ok:
|
||||
@@ -97,7 +97,7 @@ class OllamaAPIClient(LocalLLMClient):
|
||||
session = async_get_clientsession(self.hass)
|
||||
async with session.get(
|
||||
f"{self.api_host}/api/tags",
|
||||
timeout=5, # quick timeout
|
||||
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
@@ -135,7 +135,7 @@ class OllamaAPIClient(LocalLLMClient):
|
||||
async with session.post(
|
||||
f"{self.api_host}{endpoint}",
|
||||
json=request_params,
|
||||
timeout=timeout,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -3,19 +3,15 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from types import MappingProxyType
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime
|
||||
from homeassistant.data_entry_flow import (
|
||||
AbortFlow,
|
||||
FlowHandler,
|
||||
FlowManager,
|
||||
FlowResult,
|
||||
)
|
||||
from homeassistant.config_entries import (
|
||||
ConfigEntry,
|
||||
@@ -28,7 +24,6 @@ from homeassistant.config_entries import (
|
||||
ConfigEntryState,
|
||||
)
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.selector import (
|
||||
NumberSelector,
|
||||
NumberSelectorConfig,
|
||||
@@ -155,15 +150,10 @@ from .const import (
|
||||
EMBEDDED_LLAMA_CPP_PYTHON_VERSION
|
||||
)
|
||||
|
||||
from . import HomeLLMAPI, LocalLLMConfigEntry, LocalLLMClient
|
||||
from .backends.generic_openai import GenericOpenAIAPIClient, GenericOpenAIResponsesAPIClient
|
||||
from .backends.tailored_openai import TextGenerationWebuiClient, LlamaCppServerClient
|
||||
from .backends.ollama import OllamaAPIClient
|
||||
from . import HomeLLMAPI, LocalLLMConfigEntry, LocalLLMClient, BACKEND_TO_CLS
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
def is_local_backend(backend):
|
||||
return backend == BACKEND_TYPE_LLAMA_CPP
|
||||
|
||||
def pick_backend_schema(backend_type=None, selected_language=None):
|
||||
return vol.Schema(
|
||||
@@ -263,9 +253,9 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
|
||||
)
|
||||
elif self.internal_step == "pick_backend":
|
||||
if user_input:
|
||||
local_backend = is_local_backend(user_input[CONF_BACKEND_TYPE])
|
||||
backend = user_input[CONF_BACKEND_TYPE]
|
||||
self.client_config.update(user_input)
|
||||
if local_backend:
|
||||
if backend == BACKEND_TYPE_LLAMA_CPP:
|
||||
self.installed_version = await self.hass.async_add_executor_job(get_llama_cpp_python_version)
|
||||
_LOGGER.debug(f"installed version: {self.installed_version}")
|
||||
if self.installed_version == EMBEDDED_LLAMA_CPP_PYTHON_VERSION:
|
||||
@@ -301,7 +291,10 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
|
||||
selected_language=self.client_config.get(CONF_SELECTED_LANGUAGE)
|
||||
), errors=errors, last_step=False)
|
||||
elif self.internal_step == "install_local_wheels":
|
||||
if self.install_wheel_task and not self.install_wheel_task.done():
|
||||
if not self.install_wheel_task:
|
||||
return self.async_abort(reason="unknown")
|
||||
|
||||
if not self.install_wheel_task.done():
|
||||
return self.async_show_progress(
|
||||
progress_task=self.install_wheel_task,
|
||||
step_id="user",
|
||||
@@ -330,16 +323,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
|
||||
self.client_config.update(user_input)
|
||||
|
||||
# validate remote connections
|
||||
connect_err = True
|
||||
backend = self.client_config[CONF_BACKEND_TYPE]
|
||||
if backend == BACKEND_TYPE_GENERIC_OPENAI:
|
||||
connect_err = await GenericOpenAIAPIClient.async_validate_connection(self.hass, self.client_config)
|
||||
elif backend == BACKEND_TYPE_GENERIC_OPENAI_RESPONSES:
|
||||
connect_err = await GenericOpenAIResponsesAPIClient.async_validate_connection(self.hass, self.client_config)
|
||||
elif backend == BACKEND_TYPE_TEXT_GEN_WEBUI:
|
||||
connect_err = await TextGenerationWebuiClient.async_validate_connection(self.hass, self.client_config)
|
||||
elif backend == BACKEND_TYPE_OLLAMA:
|
||||
connect_err = await OllamaAPIClient.async_validate_connection(self.hass, self.client_config)
|
||||
connect_err = await BACKEND_TO_CLS[self.client_config[CONF_BACKEND_TYPE]].async_validate_connection(self.hass, self.client_config)
|
||||
|
||||
if not connect_err:
|
||||
return await self.async_step_finish()
|
||||
@@ -376,7 +360,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
|
||||
|
||||
backend = self.client_config[CONF_BACKEND_TYPE]
|
||||
title = "Generic AI Provider"
|
||||
if is_local_backend(backend):
|
||||
if backend == BACKEND_TYPE_LLAMA_CPP:
|
||||
title = f"LLama.cpp (llama-cpp-python v{self.installed_version})"
|
||||
else:
|
||||
host = self.client_config[CONF_HOST]
|
||||
@@ -409,7 +393,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
|
||||
|
||||
@classmethod
|
||||
def async_supports_options_flow(cls, config_entry: ConfigEntry) -> bool:
|
||||
return not is_local_backend(config_entry.options.get(CONF_BACKEND_TYPE))
|
||||
return config_entry.options.get(CONF_BACKEND_TYPE) != BACKEND_TYPE_LLAMA_CPP
|
||||
|
||||
@staticmethod
|
||||
def async_get_options_flow(
|
||||
@@ -440,23 +424,29 @@ class OptionsFlow(BaseOptionsFlow):
|
||||
errors = {}
|
||||
description_placeholders = {}
|
||||
|
||||
backend_type = self.config_entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
|
||||
client_config = dict(self.config_entry.options)
|
||||
|
||||
if user_input is not None:
|
||||
if user_input[CONF_LLM_HASS_API] == "none":
|
||||
user_input.pop(CONF_LLM_HASS_API)
|
||||
client_config.update(user_input)
|
||||
|
||||
# FIXME: invoke the static function on the appropriate llm client class
|
||||
# self.config_entry.runtime_data.validate_connection(user_input) # is this correct?
|
||||
# validate remote connections
|
||||
connect_err = await BACKEND_TO_CLS[backend_type].async_validate_connection(self.hass, client_config)
|
||||
|
||||
if len(errors) == 0:
|
||||
return self.async_create_entry(title="Local LLM Conversation", data=user_input)
|
||||
if not connect_err:
|
||||
return self.async_create_entry(data=client_config)
|
||||
else:
|
||||
errors["base"] = "failed_to_connect"
|
||||
description_placeholders["exception"] = str(connect_err)
|
||||
|
||||
schema = remote_connection_schema(
|
||||
backend_type=self.config_entry.options[CONF_BACKEND_TYPE],
|
||||
host=self.config_entry.options.get(CONF_HOST),
|
||||
port=self.config_entry.options.get(CONF_PORT),
|
||||
ssl=self.config_entry.options.get(CONF_SSL),
|
||||
selected_path=self.config_entry.options.get(CONF_GENERIC_OPENAI_PATH)
|
||||
backend_type=backend_type,
|
||||
host=client_config.get(CONF_HOST),
|
||||
port=client_config.get(CONF_PORT),
|
||||
ssl=client_config.get(CONF_SSL),
|
||||
selected_path=client_config.get(CONF_GENERIC_OPENAI_PATH)
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=schema,
|
||||
@@ -513,20 +503,13 @@ def build_prompt_template(selected_language: str, prompt_template_template: str)
|
||||
|
||||
def local_llama_config_option_schema(
|
||||
hass: HomeAssistant,
|
||||
parent_options: MappingProxyType[str, Any],
|
||||
options: MappingProxyType[str, Any],
|
||||
language: str,
|
||||
options: dict[str, Any],
|
||||
backend_type: str,
|
||||
subentry_type: str,
|
||||
) -> dict:
|
||||
if not options:
|
||||
options = DEFAULT_OPTIONS
|
||||
|
||||
default_prompt = build_prompt_template(parent_options[CONF_SELECTED_LANGUAGE], DEFAULT_PROMPT)
|
||||
|
||||
# TODO: we need to make this the "model config" i.e. each subentry defines all of the things to define
|
||||
# the model and the parent entry just defines the "connection options" (or llama-cpp-python version)
|
||||
|
||||
# will also need to move the model download steps to the config sub-entry
|
||||
default_prompt = build_prompt_template(language, DEFAULT_PROMPT)
|
||||
|
||||
result: dict = {
|
||||
vol.Optional(
|
||||
@@ -586,7 +569,7 @@ def local_llama_config_option_schema(
|
||||
): str
|
||||
}
|
||||
|
||||
if is_local_backend(backend_type):
|
||||
if backend_type == BACKEND_TYPE_LLAMA_CPP:
|
||||
result.update({
|
||||
vol.Required(
|
||||
CONF_TOP_K,
|
||||
@@ -919,7 +902,6 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
self.model_config: dict[str, Any] = {}
|
||||
self.download_task = None
|
||||
self.download_error = None
|
||||
self.internal_step = "pick_model"
|
||||
|
||||
@property
|
||||
def _is_new(self) -> bool:
|
||||
@@ -933,15 +915,12 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
return entry.runtime_data
|
||||
|
||||
async def async_step_pick_model(
|
||||
self, user_input: dict[str, Any] | None,
|
||||
entry: LocalLLMConfigEntry
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
schema = vol.Schema({})
|
||||
errors = {}
|
||||
description_placeholders = {}
|
||||
|
||||
if not self.model_config:
|
||||
self.model_config = {}
|
||||
entry = self._get_entry()
|
||||
|
||||
backend_type = entry.options.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
|
||||
if backend_type == BACKEND_TYPE_LLAMA_CPP:
|
||||
@@ -986,23 +965,21 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
if not model_file:
|
||||
model_name = self.model_config.get(CONF_CHAT_MODEL)
|
||||
if model_name:
|
||||
return await self.async_step_download(entry)
|
||||
return await self.async_step_download()
|
||||
else:
|
||||
errors["base"] = "no_model_name_or_file"
|
||||
|
||||
if os.path.exists(model_file):
|
||||
self.model_config[CONF_CHAT_MODEL] = os.path.basename(model_file)
|
||||
self.internal_step = "model_parameters"
|
||||
return await self.async_step_model_parameters(None, entry)
|
||||
return await self.async_step_model_parameters()
|
||||
else:
|
||||
errors["base"] = "missing_model_file"
|
||||
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(model_file)
|
||||
else:
|
||||
self.internal_step = "model_parameters"
|
||||
return await self.async_step_model_parameters(None, entry)
|
||||
return await self.async_step_model_parameters()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
step_id="pick_model",
|
||||
data_schema=schema,
|
||||
errors=errors,
|
||||
description_placeholders=description_placeholders,
|
||||
@@ -1010,27 +987,32 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
)
|
||||
|
||||
async def async_step_download(
|
||||
self, entry: LocalLLMConfigEntry
|
||||
self, user_input: dict[str, Any] | None = None,
|
||||
) -> SubentryFlowResult:
|
||||
if not self.download_task:
|
||||
model_name = self.model_config[CONF_CHAT_MODEL]
|
||||
quantization_type = self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION]
|
||||
|
||||
storage_folder = os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "models")
|
||||
self.download_task = self.hass.async_add_executor_job(
|
||||
download_model_from_hf, model_name, quantization_type, storage_folder
|
||||
)
|
||||
|
||||
async def download_task():
|
||||
await self.hass.async_add_executor_job(
|
||||
download_model_from_hf, model_name, quantization_type, storage_folder
|
||||
)
|
||||
|
||||
self.download_task = self.hass.async_create_background_task(
|
||||
download_task(), name="model_download_task")
|
||||
|
||||
return self.async_show_progress(
|
||||
progress_task=self.download_task,
|
||||
step_id="user",
|
||||
step_id="download",
|
||||
progress_action="download",
|
||||
)
|
||||
|
||||
if self.download_task and not self.download_task.done():
|
||||
return self.async_show_progress(
|
||||
progress_task=self.download_task,
|
||||
step_id="user",
|
||||
step_id="download",
|
||||
progress_action="download",
|
||||
)
|
||||
|
||||
@@ -1038,47 +1020,48 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
if download_exception:
|
||||
_LOGGER.info("Failed to download model: %s", repr(download_exception))
|
||||
self.download_error = download_exception
|
||||
self.internal_step = "select_local_model"
|
||||
|
||||
self.download_task = None
|
||||
return self.async_show_progress_done(next_step_id="failed")
|
||||
return self.async_show_progress_done(next_step_id="pick_model")
|
||||
else:
|
||||
self.model_config[CONF_DOWNLOADED_MODEL_FILE] = self.download_task.result()
|
||||
self.internal_step = "model_parameters"
|
||||
|
||||
self.download_task = None
|
||||
return self.async_show_progress_done(next_step_id="model_parameters")
|
||||
|
||||
async def async_step_model_parameters(
|
||||
self, user_input: dict[str, Any] | None,
|
||||
entry: LocalLLMConfigEntry,
|
||||
self, user_input: dict[str, Any] | None = None,
|
||||
) -> SubentryFlowResult:
|
||||
errors = {}
|
||||
description_placeholders = {}
|
||||
entry = self._get_entry()
|
||||
backend_type = entry.options[CONF_BACKEND_TYPE]
|
||||
|
||||
# determine selected language from model config or parent options
|
||||
selected_language = self.model_config.get(
|
||||
CONF_SELECTED_LANGUAGE, entry.options.get(CONF_SELECTED_LANGUAGE, "en")
|
||||
)
|
||||
model_name = self.model_config.get(CONF_CHAT_MODEL, "").lower()
|
||||
if not self.model_config:
|
||||
# determine selected language from model config or parent options
|
||||
selected_language = self.model_config.get(
|
||||
CONF_SELECTED_LANGUAGE, entry.options.get(CONF_SELECTED_LANGUAGE, "en")
|
||||
)
|
||||
model_name = self.model_config.get(CONF_CHAT_MODEL, "").lower()
|
||||
|
||||
selected_default_options = {**DEFAULT_OPTIONS}
|
||||
for key in OPTIONS_OVERRIDES.keys():
|
||||
if key in model_name:
|
||||
selected_default_options.update(OPTIONS_OVERRIDES[key])
|
||||
break
|
||||
selected_default_options = {**DEFAULT_OPTIONS}
|
||||
for key in OPTIONS_OVERRIDES.keys():
|
||||
if key in model_name:
|
||||
selected_default_options.update(OPTIONS_OVERRIDES[key])
|
||||
break
|
||||
|
||||
# Build prompt template using the selected language
|
||||
selected_default_options[CONF_PROMPT] = build_prompt_template(
|
||||
selected_language, selected_default_options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
)
|
||||
# Build prompt template using the selected language
|
||||
selected_default_options[CONF_PROMPT] = build_prompt_template(
|
||||
selected_language, str(selected_default_options.get(CONF_PROMPT, DEFAULT_PROMPT))
|
||||
)
|
||||
|
||||
self.model_config = selected_default_options
|
||||
|
||||
schema = vol.Schema(
|
||||
local_llama_config_option_schema(
|
||||
self.hass,
|
||||
entry.options,
|
||||
MappingProxyType(selected_default_options),
|
||||
entry.options[CONF_SELECTED_LANGUAGE],
|
||||
self.model_config,
|
||||
backend_type,
|
||||
self._subentry_type,
|
||||
)
|
||||
@@ -1114,17 +1097,11 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
errors["base"] = "unknown"
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
step_id="model_parameters",
|
||||
data_schema=schema,
|
||||
errors=errors,
|
||||
description_placeholders=description_placeholders,
|
||||
)
|
||||
|
||||
async def async_step_failed(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
"""Step after model downloading has failed."""
|
||||
return self.async_abort(reason="download_failed")
|
||||
|
||||
async def async_step_finish(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
@@ -1146,23 +1123,21 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
"""Handle model selection and configuration step."""
|
||||
entry: LocalLLMConfigEntry = self._get_entry()
|
||||
|
||||
# Ensure the parent entry is loaded before allowing subentry edits
|
||||
if entry.state != ConfigEntryState.LOADED:
|
||||
if self._get_entry().state != ConfigEntryState.LOADED:
|
||||
return self.async_abort(reason="entry_not_loaded")
|
||||
|
||||
if self.internal_step == "pick_model":
|
||||
return await self.async_step_pick_model(user_input, entry)
|
||||
elif self.internal_step == "download":
|
||||
return await self.async_step_download(entry)
|
||||
elif self.internal_step == "model_parameters":
|
||||
return await self.async_step_model_parameters(user_input, entry)
|
||||
else:
|
||||
return self.async_abort(reason="unknown_internal_step")
|
||||
|
||||
if not self.model_config:
|
||||
self.model_config = {}
|
||||
|
||||
return await self.async_step_pick_model(user_input)
|
||||
|
||||
async_step_init = async_step_user
|
||||
|
||||
async def async_step_reconfigure(
|
||||
self, user_input: dict[str, Any] | None = None):
|
||||
return await self.async_step_model_parameters(user_input, self._get_entry())
|
||||
if not self.model_config:
|
||||
self.model_config = dict(self._get_reconfigure_subentry().data)
|
||||
|
||||
async_step_init = async_step_user
|
||||
return await self.async_step_model_parameters(user_input)
|
||||
@@ -69,7 +69,6 @@ class TextGenerationResult:
|
||||
raise_error: bool = False
|
||||
error_msg: Optional[str] = None
|
||||
|
||||
# TODO: each client needs to support calling multiple models without re-creating the client (llama.cpp will need the most work)
|
||||
class LocalLLMClient:
|
||||
"""Base Local LLM conversation agent."""
|
||||
|
||||
|
||||
@@ -1,18 +1,9 @@
|
||||
{
|
||||
"config": {
|
||||
"error": {
|
||||
"download_failed": "The download failed to complete: {exception}",
|
||||
"missing_quantization": "The GGUF quantization level {missing} does not exist in the provided HuggingFace repo. The following quantization levels were found: {available}",
|
||||
"no_supported_ggufs": "The provided HuggingFace repo does not contain any compatible GGUF files!",
|
||||
"failed_to_connect": "Failed to connect to the remote API: {exception}",
|
||||
"missing_model_api": "The selected model is not provided by this API. The available models have been populated in the dropdown.",
|
||||
"missing_model_file": "The provided file does not exist.",
|
||||
"other_existing_local": "Another model is already loaded locally. Please unload it or configure a remote model.",
|
||||
"unknown": "Unexpected error",
|
||||
"pip_wheel_error": "Pip returned an error while installing the wheel! Please check the Home Assistant logs for more details.",
|
||||
"sys_refresh_caching_enabled": "System prompt refresh must be enabled for prompt caching to work!",
|
||||
"missing_gbnf_file": "The GBNF file was not found: {filename}",
|
||||
"missing_icl_file": "The in context learning example CSV file was not found: {filename}"
|
||||
"pip_wheel_error": "Pip returned an error while installing the wheel! Please check the Home Assistant logs for more details."
|
||||
},
|
||||
"progress": {
|
||||
"download": "Please wait while the model is being downloaded from HuggingFace. This can take a few minutes.",
|
||||
@@ -51,7 +42,6 @@
|
||||
"download_failed": "The download failed to complete: {exception}",
|
||||
"missing_quantization": "The GGUF quantization level {missing} does not exist in the provided HuggingFace repo. The following quantization levels were found: {available}",
|
||||
"no_supported_ggufs": "The provided HuggingFace repo does not contain any compatible GGUF files!",
|
||||
"failed_to_connect": "Failed to connect to the remote API: {exception}",
|
||||
"missing_model_api": "The selected model is not provided by this API. The available models have been populated in the dropdown.",
|
||||
"missing_model_file": "The provided file does not exist.",
|
||||
"other_existing_local": "Another model is already loaded locally. Please unload it or configure a remote model.",
|
||||
@@ -63,15 +53,72 @@
|
||||
"progress": {
|
||||
"download": "Please wait while the model is being downloaded from HuggingFace. This can take a few minutes."
|
||||
},
|
||||
"abort": {
|
||||
"reconfigure_successful": "Successfully updated model options."
|
||||
},
|
||||
"step": {
|
||||
"init": {
|
||||
"data": {
|
||||
"huggingface_model": "Model Name"
|
||||
"huggingface_model": "Model Name",
|
||||
"downloaded_model_file": "Local file name",
|
||||
"downloaded_model_quantization": "Downloaded model quantization"
|
||||
},
|
||||
"description": "Select a model to use. \n\n**Models supported out of the box:**\n1. [Home LLM](https://huggingface.co/collections/acon96/home-llm-6618762669211da33bb22c5a): Home 3B & Home 1B\n2. Mistral: [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) or [Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)\n3. Llama 3: [8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and [70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)",
|
||||
"title": "Pick Model"
|
||||
},
|
||||
"user": {
|
||||
"model_parameters": {
|
||||
"data": {
|
||||
"max_new_tokens": "Maximum tokens to return in response",
|
||||
"llm_hass_api": "Selected LLM API",
|
||||
"prompt": "System Prompt",
|
||||
"temperature": "Temperature",
|
||||
"top_k": "Top K",
|
||||
"top_p": "Top P",
|
||||
"min_p": "Min P",
|
||||
"typical_p": "Typical P",
|
||||
"request_timeout": "Remote Request Timeout (seconds)",
|
||||
"ollama_keep_alive": "(ollama) Keep Alive/Inactivity Timeout (minutes)",
|
||||
"ollama_json_mode": "(ollama) JSON Output Mode",
|
||||
"extra_attributes_to_expose": "Additional attribute to expose in the context",
|
||||
"enable_flash_attention": "Enable Flash Attention",
|
||||
"gbnf_grammar": "Enable GBNF Grammar",
|
||||
"gbnf_grammar_file": "GBNF Grammar Filename",
|
||||
"openai_api_key": "API Key",
|
||||
"text_generation_webui_admin_key": "(text-generation-webui) Admin Key",
|
||||
"service_call_regex": "Service Call Regex",
|
||||
"refresh_prompt_per_turn": "Refresh System Prompt Every Turn",
|
||||
"remember_conversation": "Remember conversation",
|
||||
"remember_num_interactions": "Number of past interactions to remember",
|
||||
"in_context_examples": "Enable in context learning (ICL) examples",
|
||||
"in_context_examples_file": "In context learning examples CSV filename",
|
||||
"num_in_context_examples": "Number of ICL examples to generate",
|
||||
"text_generation_webui_preset": "(text-generation-webui) Generation Preset/Character Name",
|
||||
"text_generation_webui_chat_mode": "(text-generation-webui) Chat Mode",
|
||||
"prompt_caching": "Enable Prompt Caching",
|
||||
"prompt_caching_interval": "Prompt Caching fastest refresh interval (sec)",
|
||||
"context_length": "Context Length",
|
||||
"batch_size": "(llama.cpp) Batch Size",
|
||||
"n_threads": "(llama.cpp) Thread Count",
|
||||
"n_batch_threads": "(llama.cpp) Batch Thread Count",
|
||||
"thinking_prefix": "Reasoning Content Prefix",
|
||||
"thinking_suffix": "Reasoning Content Suffix",
|
||||
"tool_call_prefix": "Tool Call Prefix",
|
||||
"tool_call_suffix": "Tool Call Suffix",
|
||||
"enable_legacy_tool_calling": "Enable Legacy Tool Calling",
|
||||
"max_tool_call_iterations": "Maximum Tool Call Attempts"
|
||||
},
|
||||
"data_description": {
|
||||
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-3)'",
|
||||
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
|
||||
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
|
||||
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
|
||||
"gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.",
|
||||
"prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below."
|
||||
},
|
||||
"description": "Please configure the model according to how it should be prompted. There are many different options and selecting the correct ones for your model is essential to getting optimal performance. See [here](https://github.com/acon96/home-llm/blob/develop/docs/Backend%20Configuration.md) for more information about the options on this page.\n\n**Some defaults may have been chosen for you based on the name of the selected model name or filename.** If you renamed a file or are using a fine-tuning of a supported model, then the defaults may not have been detected.",
|
||||
"title": "Configure the selected model"
|
||||
},
|
||||
"reconfigure": {
|
||||
"data": {
|
||||
"max_new_tokens": "Maximum tokens to return in response",
|
||||
"llm_hass_api": "Selected LLM API",
|
||||
@@ -141,6 +188,10 @@
|
||||
"text_generation_webui_chat_mode": "Chat Mode"
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"failed_to_connect": "Failed to connect to the remote API: {exception}",
|
||||
"unknown": "Unexpected error"
|
||||
}
|
||||
},
|
||||
"selector": {
|
||||
|
||||
Reference in New Issue
Block a user