Mirror of https://github.com/acon96/home-llm.git, synced 2026-01-09 13:48:05 -05:00
config sub-entry flow is mostly working now
@@ -6,6 +6,7 @@
- [ ] new model based on gemma3 270m
- [ ] support AI task API
- [ ] move llamacpp to a separate process because of all the crashing
- [ ] optional sampling parameters in options panel (don't pass to backend if not set)
- [x] support new LLM APIs
  - rewrite how services are called
  - handle no API selected
@@ -169,8 +169,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):

    def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
        request_params = {}
        api_base_path = entity_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
        endpoint = f"/{api_base_path}/chat/completions"
        endpoint = "/chat/completions"
        return endpoint, request_params

    def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None) -> Tuple[Optional[str], Optional[List]]:
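The hunk above shows GenericOpenAIAPIClient deriving its chat-completions endpoint from the CONF_GENERIC_OPENAI_PATH option. A standalone sketch of that endpoint construction, assuming a plain string option key and default (the integration's actual constants differ):

    # Illustrative sketch only, not the integration's code.
    # The "generic_openai_path" key and the "v1" default are assumptions.
    def build_chat_completions_endpoint(entity_options: dict, default_path: str = "v1") -> str:
        api_base_path = (entity_options.get("generic_openai_path") or default_path).strip("/")
        return f"/{api_base_path}/chat/completions"

    # build_chat_completions_endpoint({"generic_openai_path": "v1"})  -> "/v1/chat/completions"
    # build_chat_completions_endpoint({})                             -> "/v1/chat/completions"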
@@ -13,7 +13,7 @@ from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_D
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import callback
from homeassistant.core import callback, HomeAssistant
from homeassistant.exceptions import ConfigEntryError, HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.event import async_track_state_change, async_call_later
@@ -81,7 +81,7 @@ def snapshot_settings(options: dict[str, Any]) -> dict[str, Any]:


class LlamaCppClient(LocalLLMClient):
    llama_cpp_module: Any
    llama_cpp_module: Any | None

    models: dict[str, LlamaType]
    grammars: dict[str, Any]
@@ -96,6 +96,20 @@ class LlamaCppClient(LocalLLMClient):

    _attr_supports_streaming = True

    def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
        super().__init__(hass, client_options)

        self.llama_cpp_module = None
        self.models = {}
        self.grammars = {}
        self.loaded_model_settings = {}

        self.remove_prompt_caching_listener = None
        self.last_cache_prime = 0.0
        self.last_updated_entities = {}
        self.cache_refresh_after_cooldown = False
        self.model_lock = threading.Lock()

    async def async_get_available_models(self) -> List[str]:
        return [] # TODO: copy from config_flow.py
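The __init__ hunk above moves all per-model state (loaded models, grammars, prompt-cache bookkeeping, the model lock) into the client constructor so a single LlamaCppClient can serve several sub-entries. A minimal standalone sketch of that keyed-by-model-name pattern, using only the standard library:

    import threading
    from typing import Any, Dict

    class ModelRegistry:
        """Sketch only: per-model state keyed by model name, mirroring the dicts above."""

        def __init__(self) -> None:
            self.models: Dict[str, Any] = {}
            self.grammars: Dict[str, Any] = {}
            self.loaded_model_settings: Dict[str, Dict[str, Any]] = {}
            self.model_lock = threading.Lock()

        def is_loaded(self, model_name: str) -> bool:
            # Check under the lock so a concurrent load cannot race the lookup.
            with self.model_lock:
                return model_name in self.models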
@@ -103,6 +117,10 @@ class LlamaCppClient(LocalLLMClient):
        model_name = entity_options.get(CONF_CHAT_MODEL, "")
        model_path = entity_options.get(CONF_DOWNLOADED_MODEL_FILE, "")

        if model_name in self.models:
            _LOGGER.info("Model %s is already loaded", model_name)
            return

        _LOGGER.info("Using model file '%s'", model_path)

        if not model_path or not os.path.isfile(model_path):
@@ -144,13 +162,6 @@ class LlamaCppClient(LocalLLMClient):
        # cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
        # ))

        self.remove_prompt_caching_listener = None
        self.last_cache_prime = 0.0
        self.last_updated_entities = {}
        self.cache_refresh_after_cooldown = False
        self.model_lock = threading.Lock()


        if model_settings[CONF_PROMPT_CACHING_ENABLED]:
            @callback
            async def enable_caching_after_startup(_now) -> None:
@@ -222,10 +233,10 @@ class LlamaCppClient(LocalLLMClient):

        if entity_options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
            current_grammar = entity_options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
            if not self.grammar or self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] != current_grammar:
            if model_name not in self.grammars or self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] != current_grammar:
                self._load_grammar(model_name, current_grammar)
        else:
            self.grammar = None
            self.grammars[model_name] = None

        if entity_options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
            self._set_prompt_caching(entity_options, enabled=True)
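The grammar handling above replaces the single self.grammar attribute with a per-model self.grammars dict. A small sketch of that cache pattern; the loader callable and the per-model settings dict are stand-ins, not the integration's API:

    from typing import Any, Callable, Dict, Optional

    def get_grammar(
        grammars: Dict[str, Optional[Any]],
        loaded_grammar_files: Dict[str, str],
        model_name: str,
        grammar_file: str,
        loader: Callable[[str], Any],
    ) -> Optional[Any]:
        # Reload only if this model has no grammar yet or its configured file changed.
        if model_name not in grammars or loaded_grammar_files.get(model_name) != grammar_file:
            grammars[model_name] = loader(grammar_file)
            loaded_grammar_files[model_name] = grammar_file
        return grammars[model_name]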
@@ -342,17 +353,17 @@ class LlamaCppClient(LocalLLMClient):
        if llm_api:
            tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())

        model_name = entity_options.get(CONF_CHAT_MODEL, "")
        temperature = entity_options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_k = int(entity_options.get(CONF_TOP_K, DEFAULT_TOP_K))
        top_p = entity_options.get(CONF_TOP_P, DEFAULT_TOP_P)
        min_p = entity_options.get(CONF_MIN_P, DEFAULT_MIN_P)
        typical_p = entity_options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
        grammar = self.grammar if entity_options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None
        grammar = self.grammars.get(model_name) if entity_options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None

        _LOGGER.debug("Priming model cache via chat completion API...")

        try:
            model_name = entity_options.get(CONF_CHAT_MODEL, "")

            # avoid strict typing issues from the llama-cpp-python bindings
            self.models[model_name].create_chat_completion(
                messages,
@@ -406,6 +417,7 @@ class LlamaCppClient(LocalLLMClient):
        top_p = entity_options.get(CONF_TOP_P, DEFAULT_TOP_P)
        min_p = entity_options.get(CONF_MIN_P, DEFAULT_MIN_P)
        typical_p = entity_options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
        grammar = self.grammars.get(model_name) if entity_options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None

        _LOGGER.debug(f"Options: {entity_options}")
@@ -442,7 +454,7 @@ class LlamaCppClient(LocalLLMClient):
            min_p=min_p,
            typical_p=typical_p,
            max_tokens=max_tokens,
            grammar=self.grammar,
            grammar=grammar,
            stream=True,
        )
@@ -138,8 +138,6 @@ from .const import (
    DEFAULT_LLAMACPP_BATCH_SIZE,
    DEFAULT_LLAMACPP_THREAD_COUNT,
    DEFAULT_LLAMACPP_BATCH_THREAD_COUNT,
    BACKEND_TYPE_LLAMA_HF_SETUP,
    BACKEND_TYPE_LLAMA_EXISTING_SETUP,
    BACKEND_TYPE_LLAMA_CPP,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
    BACKEND_TYPE_GENERIC_OPENAI,
@@ -165,7 +163,7 @@ from .backends.ollama import OllamaAPIClient
_LOGGER = logging.getLogger(__name__)

def is_local_backend(backend):
    return backend in [BACKEND_TYPE_LLAMA_EXISTING_SETUP, BACKEND_TYPE_LLAMA_HF_SETUP, BACKEND_TYPE_LLAMA_CPP]
    return backend == BACKEND_TYPE_LLAMA_CPP

def pick_backend_schema(backend_type=None, selected_language=None):
    return vol.Schema(
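With the change above, only the in-process llama.cpp backend counts as local; the old HF-download and existing-file variants no longer exist as separate backend types. A trivial usage sketch of the updated check:

    BACKEND_TYPE_LLAMA_CPP = "llama_cpp_python"

    def is_local_backend(backend: str) -> bool:
        return backend == BACKEND_TYPE_LLAMA_CPP

    assert is_local_backend("llama_cpp_python")
    assert not is_local_backend("generic_openai")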
@@ -442,23 +440,17 @@ class OptionsFlow(BaseOptionsFlow):
    )


def STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file=None):
def STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(model_file=None, chat_model=None, downloaded_model_quantization=None, available_quantizations=None):
    return vol.Schema(
        {
            vol.Required(CONF_DOWNLOADED_MODEL_FILE, default=model_file if model_file else ""): str,
        }
    )

def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_quantization=None, available_quantizations=None):
    return vol.Schema(
        {
            vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): SelectSelector(SelectSelectorConfig(
            vol.Optional(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): SelectSelector(SelectSelectorConfig(
                options=RECOMMENDED_CHAT_MODELS,
                custom_value=True,
                multiple=False,
                mode=SelectSelectorMode.DROPDOWN,
            )),
            vol.Required(CONF_DOWNLOADED_MODEL_QUANTIZATION, default=downloaded_model_quantization if downloaded_model_quantization else DEFAULT_DOWNLOADED_MODEL_QUANTIZATION): vol.In(available_quantizations if available_quantizations else CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS),
            vol.Optional(CONF_DOWNLOADED_MODEL_QUANTIZATION, default=downloaded_model_quantization if downloaded_model_quantization else DEFAULT_DOWNLOADED_MODEL_QUANTIZATION): vol.In(available_quantizations if available_quantizations else CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS),
            vol.Optional(CONF_DOWNLOADED_MODEL_FILE, default=model_file if model_file else ""): str,
        }
    )
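The hunk above folds the old existing-file and HuggingFace-download schemas into one STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA in which model name, quantization, and model file are all optional. A simplified voluptuous sketch of that merged shape; the plain string keys and the quantization list are placeholders for the real CONF_* constants and selector configuration:

    import voluptuous as vol

    QUANTIZATION_OPTIONS = ["Q4_K_M", "Q5_K_M", "Q8_0"]  # placeholder list

    def model_selection_schema(model_file="", chat_model="", quantization="Q4_K_M"):
        return vol.Schema(
            {
                vol.Optional("huggingface_model", default=chat_model): str,
                vol.Optional("downloaded_model_quantization", default=quantization): vol.In(QUANTIZATION_OPTIONS),
                vol.Optional("downloaded_model_file", default=model_file): str,
            }
        )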
@@ -852,7 +844,7 @@ def local_llama_config_option_schema(
        CONF_TEMPERATURE,
        CONF_TOP_P,
        CONF_MIN_P,
        CONF_TYPICAL_P
        CONF_TYPICAL_P,
        CONF_TOP_K,
        # tool calling/reasoning
        CONF_THINKING_PREFIX,
@@ -927,10 +919,8 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
        self.model_config = {}

        backend_type = entry.options.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
        if backend_type == BACKEND_TYPE_LLAMA_HF_SETUP:
            schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA()
        elif backend_type == BACKEND_TYPE_LLAMA_EXISTING_SETUP:
            schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA()
        if backend_type == BACKEND_TYPE_LLAMA_CPP:
            schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA()
        else:
            schema = STEP_REMOTE_MODEL_SELECTION_DATA_SCHEMA(await entry.runtime_data.async_get_available_models())
@@ -940,7 +930,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):

        if len(available_quants) == 0:
            errors["base"] = "no_supported_ggufs"
            schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(
            schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
                chat_model=self.model_config[CONF_CHAT_MODEL],
                downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION],
            )
@@ -949,7 +939,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
            description_placeholders["missing"] = self.download_error.missing_quant
            description_placeholders["available"] = ", ".join(self.download_error.available_quants)

            schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(
            schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
                chat_model=self.model_config[CONF_CHAT_MODEL],
                downloaded_model_quantization=self.download_error.available_quants[0],
                available_quantizations=available_quants,
@@ -957,7 +947,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
        else:
            errors["base"] = "download_failed"
            description_placeholders["exception"] = str(self.download_error)
            schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(
            schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(
                chat_model=self.model_config[CONF_CHAT_MODEL],
                downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION],
            )
@@ -966,17 +956,24 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):

        self.model_config.update(user_input)

        if backend_type == BACKEND_TYPE_LLAMA_HF_SETUP:
            return await self.async_step_download(entry)
        elif backend_type == BACKEND_TYPE_LLAMA_EXISTING_SETUP:
            model_file = self.model_config[CONF_DOWNLOADED_MODEL_FILE]
        if backend_type == BACKEND_TYPE_LLAMA_CPP:
            model_file = self.model_config.get(CONF_DOWNLOADED_MODEL_FILE, "")
            if not model_file:
                model_name = self.model_config.get(CONF_CHAT_MODEL)
                if model_name:
                    return await self.async_step_download(entry)
                else:
                    errors["base"] = "no_model_name_or_file"

            if os.path.exists(model_file):
                self.model_config[CONF_CHAT_MODEL] = os.path.basename(model_file)
                self.internal_step = "model_parameters"
                return await self.async_step_model_parameters(None, entry)
            else:
                errors["base"] = "missing_model_file"
                schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file)
                schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA(model_file)
        else:
            self.internal_step = "model_parameters"
            return await self.async_step_model_parameters(None, entry)

        return self.async_show_form(
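The BACKEND_TYPE_LLAMA_CPP branch above picks between downloading a model and using an existing file based on what was supplied. A compact sketch of that decision; the step names and dictionary keys are illustrative only:

    import os

    def next_local_step(model_config: dict) -> str:
        model_file = model_config.get("downloaded_model_file", "")
        if not model_file:
            # No file given: fall back to downloading by model name, if one was entered.
            return "download" if model_config.get("huggingface_model") else "error:no_model_name_or_file"
        # A file was given: it must exist on disk before moving on to model parameters.
        return "model_parameters" if os.path.exists(model_file) else "error:missing_model_file"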
@@ -1017,12 +1014,15 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
            _LOGGER.info("Failed to download model: %s", repr(download_exception))
            self.download_error = download_exception
            self.internal_step = "select_local_model"

            self.download_task = None
            return self.async_show_progress_done(next_step_id="failed")
        else:
            self.model_config[CONF_DOWNLOADED_MODEL_FILE] = self.download_task.result()
            self.internal_step = "model_parameters"

            self.download_task = None
            return self.async_show_progress_done(next_step_id="finish")
        self.download_task = None
        return self.async_show_progress_done(next_step_id="model_parameters")

    async def async_step_model_parameters(
        self, user_input: dict[str, Any] | None,
@@ -1082,9 +1082,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
        try:
            # validate input
            schema(user_input)

            self.model_config.update(user_input)

            return await self.async_step_finish()
        except Exception:
            _LOGGER.exception("An unknown error has occurred!")
@@ -104,8 +104,8 @@ DEFAULT_TEMPERATURE = 0.1
CONF_REQUEST_TIMEOUT = "request_timeout"
DEFAULT_REQUEST_TIMEOUT = 90
CONF_BACKEND_TYPE = "model_backend"
BACKEND_TYPE_LLAMA_HF_SETUP = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING_SETUP = "llama_cpp_existing"
BACKEND_TYPE_LLAMA_HF_OLD = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING_OLD = "llama_cpp_existing"
BACKEND_TYPE_LLAMA_CPP = "llama_cpp_python"
BACKEND_TYPE_TEXT_GEN_WEBUI = "text-generation-webui_api"
BACKEND_TYPE_GENERIC_OPENAI = "generic_openai"
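The constants above keep the old backend strings under *_OLD names while the active backend collapses to BACKEND_TYPE_LLAMA_CPP. A hypothetical sketch (not taken from this commit) of the kind of migration mapping such a rename typically enables:

    BACKEND_TYPE_LLAMA_HF_OLD = "llama_cpp_hf"
    BACKEND_TYPE_LLAMA_EXISTING_OLD = "llama_cpp_existing"
    BACKEND_TYPE_LLAMA_CPP = "llama_cpp_python"

    # Hypothetical: map legacy config entries onto the consolidated backend type.
    LEGACY_BACKEND_MAP = {
        BACKEND_TYPE_LLAMA_HF_OLD: BACKEND_TYPE_LLAMA_CPP,
        BACKEND_TYPE_LLAMA_EXISTING_OLD: BACKEND_TYPE_LLAMA_CPP,
    }

    def migrate_backend_type(backend: str) -> str:
        return LEGACY_BACKEND_MAP.get(backend, backend)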
@@ -1,12 +1,14 @@
"""Defines the various LLM Backend Agents"""
from __future__ import annotations

from typing import Literal
import logging

from homeassistant.components.conversation import ConversationInput, ConversationResult, ConversationEntity
from homeassistant.components.conversation.models import AbstractConversationAgent
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.helpers import chat_session
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
@@ -21,20 +23,31 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry, asy
    """Set up Local LLM Conversation from a config entry."""

    for subentry in entry.subentries.values():
        if subentry.subentry_type != "conversation":
        if subentry.subentry_type != conversation.DOMAIN:
            continue

        # create one agent entity per conversation subentry
        agent_entity = LocalLLMAgent(hass, entry, entry.runtime_data, subentry)
        agent_entity = LocalLLMAgent(hass, entry, subentry, entry.runtime_data)

        # make sure model is loaded
        await entry.runtime_data._async_load_model(dict(subentry.data))

        # register the agent entity
        async_add_entities([agent_entity])

    return True

class LocalLLMAgent(LocalLLMEntity, ConversationEntity, AbstractConversationAgent):
class LocalLLMAgent(ConversationEntity, AbstractConversationAgent, LocalLLMEntity):
    """Base Local LLM conversation agent."""

    def __init__(self, hass: HomeAssistant, entry: ConfigEntry, subentry: ConfigSubentry, client: LocalLLMClient) -> None:
        super().__init__(hass, entry, subentry, client)

        if subentry.data.get(CONF_LLM_HASS_API):
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL
            )

    async def async_added_to_hass(self) -> None:
        """When entity is added to Home Assistant."""
        await super().async_added_to_hass()
@@ -45,6 +58,11 @@ class LocalLLMAgent(LocalLLMEntity, ConversationEntity, AbstractConversationAgen
        conversation.async_unset_agent(self.hass, self.entry)
        await super().async_will_remove_from_hass()

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    async def async_process(
        self, user_input: ConversationInput
    ) -> ConversationResult:
@@ -130,6 +130,16 @@ class LocalLLMClient:
        await self.hass.async_add_executor_job(
            self._load_model, entity_options
        )

    def _unload_model(self, entity_options: dict[str, Any]) -> None:
        """Unload the model to free up space on the backend. Implemented by sub-classes"""
        pass

    async def _async_unload_model(self, entity_options: dict[str, Any]) -> None:
        """Default implementation is to call _unload_model() which probably does blocking stuff"""
        await self.hass.async_add_executor_job(
            self._unload_model, entity_options
        )

    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput, entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
        """Async generator for streaming responses. Subclasses should implement."""
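The hunk above adds an unload hook mirroring the load hook: subclasses override the blocking _unload_model() and the async wrapper pushes it onto an executor. A standalone sketch of the same pattern using asyncio directly, where run_in_executor stands in for hass.async_add_executor_job:

    import asyncio
    from typing import Any, Dict

    class BaseClient:
        def _unload_model(self, entity_options: Dict[str, Any]) -> None:
            """Blocking unload; subclasses override."""

        async def _async_unload_model(self, entity_options: Dict[str, Any]) -> None:
            # Run the (potentially blocking) unload in the default thread pool.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._unload_model, entity_options)

    class InMemoryClient(BaseClient):
        def __init__(self) -> None:
            self.models: Dict[str, Any] = {}

        def _unload_model(self, entity_options: Dict[str, Any]) -> None:
            # Drop the loaded model object so its memory can be reclaimed.
            self.models.pop(entity_options.get("huggingface_model", ""), None)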
@@ -632,10 +642,6 @@ class LocalLLMEntity(entity.Entity):
        self.subentry_id = subentry.subentry_id
        self.client = client

        if subentry.data.get(CONF_LLM_HASS_API):
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL
            )

    def handle_reload(self):
        self.client._update_options(self.runtime_options)