work through todos

Alex O'Connell
2025-09-15 22:10:49 -04:00
parent 18a34a3d5b
commit bfc6a5a753
5 changed files with 92 additions and 156 deletions

View File

@@ -50,16 +50,15 @@
- [x] split out entity functionality so we can support conversation + ai tasks
- [x] fix icl examples to match new tool calling syntax config
- [x] set up docker-compose for running all of the various backends
- [x] fix the openai responses backend
- [ ] config sub-entry implementation
- [x] config sub-entry implementation
- [x] base work
- [x] generic openai backend
- [x] llamacpp backend
- [x] ollama backend
- [ ] tailored_openai backend
- [ ] generic openai responses backend
- [x] tailored_openai backend
- [x] generic openai responses backend
- [ ] fix and re-upload all compatible old models (+ upload all original safetensors)
- [ ] config entry migration function
- [x] config entry migration function
## more complicated ideas
- [ ] "context requests"

View File

@@ -5,7 +5,7 @@ import logging
from typing import Final
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import ATTR_ENTITY_ID, Platform
from homeassistant.const import ATTR_ENTITY_ID, Platform, CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, llm, device_registry as dr, entity_registry as er
from homeassistant.util.json import JsonObjectType
@@ -21,6 +21,19 @@ from .const import (
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
CONF_BACKEND_TYPE,
CONF_SELECTED_LANGUAGE,
CONF_OPENAI_API_KEY,
CONF_GENERIC_OPENAI_PATH,
CONF_CHAT_MODEL, CONF_DOWNLOADED_MODEL_QUANTIZATION, CONF_DOWNLOADED_MODEL_FILE, CONF_REQUEST_TIMEOUT, CONF_MAX_TOOL_CALL_ITERATIONS,
CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS, CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
CONF_PROMPT, CONF_TEMPERATURE, CONF_TOP_K, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, CONF_MAX_TOKENS,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_THINKING_PREFIX, CONF_THINKING_SUFFIX, CONF_TOOL_CALL_PREFIX, CONF_TOOL_CALL_SUFFIX,
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL, CONF_CONTEXT_LENGTH,
CONF_LLAMACPP_BATCH_SIZE, CONF_LLAMACPP_THREAD_COUNT, CONF_LLAMACPP_BATCH_THREAD_COUNT,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE,
CONF_TEXT_GEN_WEBUI_PRESET, CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_ENABLE_LEGACY_TOOL_CALLING,
CONF_OLLAMA_JSON_MODE, CONF_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_BACKEND_TYPE,
BACKEND_TYPE_LLAMA_CPP,
BACKEND_TYPE_TEXT_GEN_WEBUI,
@@ -89,10 +102,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) ->
hass.data[DOMAIN].pop(entry.entry_id)
return True
# TODO: split out which options are per-model and which ones are conversation-specific
# and only migrate the conversation-specific ones to the subentry
ENTRY_KEYS = []
SUBENTRY_KEYS = []
async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigEntry):
"""Migrate old entry."""
@@ -102,88 +111,57 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigE
if config_entry.version == 1:
_LOGGER.error("Cannot upgrade models that were created prior to v0.3. Please delete and re-create them.")
return False
# If already at or above the target version nothing to do
if config_entry.version >= 3:
_LOGGER.debug("Entry already migrated (version %s)", config_entry.version)
return True
# Migrate each existing config entry to use subentries for conversations
# We will create a conversation subentry using the entry.options plus any
# model identifier stored in entry.data (CONF_CHAT_MODEL / CONF_DOWNLOADED_MODEL_FILE)
# Migrate the config_entry to be an entry + sub-entry
if config_entry.version == 2:
ENTRY_DATA_KEYS = [CONF_BACKEND_TYPE, CONF_SELECTED_LANGUAGE]
ENTRY_OPTIONS_KEYS = [CONF_HOST, CONF_PORT, CONF_SSL, CONF_OPENAI_API_KEY, CONF_GENERIC_OPENAI_PATH]
SUBENTRY_KEYS = [
CONF_CHAT_MODEL, CONF_DOWNLOADED_MODEL_QUANTIZATION, CONF_DOWNLOADED_MODEL_FILE, CONF_REQUEST_TIMEOUT, CONF_MAX_TOOL_CALL_ITERATIONS,
CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS, CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
CONF_LLM_HASS_API, CONF_PROMPT, CONF_TEMPERATURE, CONF_TOP_K, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, CONF_MAX_TOKENS,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_THINKING_PREFIX, CONF_THINKING_SUFFIX, CONF_TOOL_CALL_PREFIX, CONF_TOOL_CALL_SUFFIX,
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL, CONF_CONTEXT_LENGTH,
CONF_LLAMACPP_BATCH_SIZE, CONF_LLAMACPP_THREAD_COUNT, CONF_LLAMACPP_BATCH_THREAD_COUNT,
CONF_LLAMACPP_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE,
CONF_TEXT_GEN_WEBUI_PRESET, CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_ENABLE_LEGACY_TOOL_CALLING,
CONF_OLLAMA_JSON_MODE, CONF_OLLAMA_KEEP_ALIVE_MIN
]
entries = sorted(
hass.config_entries.async_entries(DOMAIN),
key=lambda e: e.disabled_by is not None,
)
# Build entry data/options & subentry data from existing options and model info
source_data = {**config_entry.data}
source_data.update(config_entry.options)
for entry in entries:
# Skip entries that already have subentries
if entry.subentries:
continue
# Build subentry data from existing options and model info
subentry_data = { k: v for k, v in (entry.options or {}).items() if k in SUBENTRY_KEYS }
entry_data = { k: v for k, v in (entry.data or {}).items() if k in ENTRY_KEYS }
entry_data = { k: v for k, v in source_data.items() if k in ENTRY_DATA_KEYS }
entry_options = { k: v for k, v in source_data.items() if k in ENTRY_OPTIONS_KEYS }
subentry_data = { k: v for k, v in source_data.items() if k in SUBENTRY_KEYS }
subentry = ConfigSubentry(
data=MappingProxyType(subentry_data),
subentry_type="conversation",
title=entry.title,
title=config_entry.title, # FIXME: should be the "new" name format
unique_id=None,
)
hass.config_entries.async_add_subentry(entry, subentry)
# don't attempt to create duplicate llama.cpp backends
if entry_data[CONF_BACKEND_TYPE] == BACKEND_TYPE_LLAMA_CPP:
all_entries = hass.config_entries.async_entries(DOMAIN)
for entry in all_entries:
if entry.version < 3 or entry.data[CONF_BACKEND_TYPE] != BACKEND_TYPE_LLAMA_CPP:
continue
# Move entity/device registry associations to the new subentry where applicable
entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass)
await hass.config_entries.async_remove(config_entry.entry_id)
config_entry = entry
break
conversation_entity_id = entity_registry.async_get_entity_id(
"conversation",
DOMAIN,
entry.entry_id,
)
device = device_registry.async_get_device(identifiers={(DOMAIN, entry.entry_id)})
# create sub-entry
hass.config_entries.async_add_subentry(config_entry, subentry)
if conversation_entity_id is not None:
conversation_entity_entry = entity_registry.entities[conversation_entity_id]
entity_disabled_by = conversation_entity_entry.disabled_by
# Keep a sensible disabled flag when migrating
if (
entity_disabled_by is er.RegistryEntryDisabler.CONFIG_ENTRY
and not all(e.disabled_by is not None for e in entries if e.entry_id != entry.entry_id)
):
entity_disabled_by = (
er.RegistryEntryDisabler.DEVICE if device else er.RegistryEntryDisabler.USER
)
entity_registry.async_update_entity(
conversation_entity_id,
config_entry_id=entry.entry_id,
config_subentry_id=subentry.subentry_id,
disabled_by=entity_disabled_by,
new_unique_id=subentry.subentry_id,
)
# update the parent entry
hass.config_entries.async_update_entry(config_entry, data=entry_data, options=entry_options, version=3)
if device is not None:
# Adjust device registry identifiers to point to the subentry
device_disabled_by = device.disabled_by
if (
device.disabled_by is dr.DeviceEntryDisabler.CONFIG_ENTRY
):
device_disabled_by = dr.DeviceEntryDisabler.USER
device_registry.async_update_device(
device.id,
disabled_by=device_disabled_by,
new_identifiers={(DOMAIN, subentry.subentry_id)},
add_config_subentry_id=subentry.subentry_id,
add_config_entry_id=entry.entry_id,
)
# Update the parent entry to remove model-level fields and clear options
hass.config_entries.async_update_entry(entry, data=entry_data, options={}, version=3)
_LOGGER.debug("Migration to subentries complete")
_LOGGER.debug("Migration to subentries complete")
return True
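
For orientation, here is an editor's sketch of what the migration above does to a single v2 entry (not part of the commit; the key names follow the CONF_* constants imported above and the literal values are invented): connection-level settings stay on the entry, while model and conversation settings move into a new "conversation" subentry.

# Editor's illustration of the v2 -> v3 entry shape produced by async_migrate_entry above.
v2_entry = {
    "data":    {"backend_type": "generic_openai", "chat_model": "llama-3"},  # everything on the entry
    "options": {"host": "localhost", "port": 8080, "temperature": 0.7},
}
v3_entry = {
    "data":    {"backend_type": "generic_openai"},              # ENTRY_DATA_KEYS
    "options": {"host": "localhost", "port": 8080},             # ENTRY_OPTIONS_KEYS
    "subentries": [{
        "subentry_type": "conversation",
        "data": {"chat_model": "llama-3", "temperature": 0.7},  # SUBENTRY_KEYS
    }],
}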

View File

@@ -45,7 +45,6 @@ class GenericOpenAIAPIClient(LocalLLMClient):
"""Implements the OpenAPI-compatible text completion and chat completion API backends."""
api_host: str
api_base_path: str
api_key: str
_attr_supports_streaming = True
@@ -78,7 +77,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
ssl=user_input[CONF_SSL],
path=f"/{api_base_path}/models"
),
timeout=5, # quick timeout
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
headers=headers
) as response:
if response.ok:
@@ -96,7 +95,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
session = async_get_clientsession(self.hass)
async with session.get(
f"{self.api_host}/models",
timeout=5, # quick timeout
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
headers=headers
) as response:
response.raise_for_status()
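
Both hunks above swap a bare timeout=5 for an aiohttp.ClientTimeout object. A minimal standalone sketch of that pattern (editor's example; the function name and URL are illustrative):

import aiohttp

async def fetch_models(session: aiohttp.ClientSession, base_url: str) -> dict:
    # newer aiohttp releases expect a ClientTimeout object rather than a bare number
    timeout = aiohttp.ClientTimeout(total=5)  # 5-second cap on the whole request
    async with session.get(f"{base_url}/models", timeout=timeout) as response:
        response.raise_for_status()
        return await response.json()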
@@ -212,7 +211,6 @@ class GenericOpenAIAPIClient(LocalLLMClient):
return response_text, tool_calls
# FIXME: this class is mostly broken
class GenericOpenAIResponsesAPIClient(LocalLLMClient):
"""Implements the OpenAPI-compatible Responses API backend."""
@@ -224,21 +222,21 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
_last_response_id: str | None = None
_last_response_id_time: datetime.datetime | None = None
async def _async_load_model(self, entry: ConfigEntry) -> None:
def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
super().__init__(hass, client_options)
self.api_host = format_url(
hostname=entry.data[CONF_HOST],
port=entry.data[CONF_PORT],
ssl=entry.data[CONF_SSL],
path=""
hostname=client_options[CONF_HOST],
port=client_options[CONF_PORT],
ssl=client_options[CONF_SSL],
path="/" + client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
)
self.api_key = entry.data.get(CONF_OPENAI_API_KEY, "")
self.api_key = client_options.get(CONF_OPENAI_API_KEY, "")
def _responses_params(self, conversation: List[conversation.Content], api_base_path: str) -> Tuple[str, Dict[str, Any]]:
def _responses_params(self, conversation: List[conversation.Content], entity_options: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
endpoint = f"/{api_base_path}/responses"
endpoint = "/responses"
# Find the last user message in the conversation and use its content as the input
input_text: str | None = None
for msg in reversed(conversation):
@@ -256,9 +254,9 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
request_params["input"] = input_text
# Assign previous_response_id if relevant
if self._last_response_id and self._last_response_id_time and self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION):
if self._last_response_id and self._last_response_id_time and entity_options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION):
# If the last response was generated recently, use it as a context
configured_memory_time: datetime.timedelta = datetime.timedelta(minutes=self.entry.options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES))
configured_memory_time: datetime.timedelta = datetime.timedelta(minutes=entity_options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES))
last_conversation_age: datetime.timedelta = datetime.datetime.now() - self._last_response_id_time
_LOGGER.debug(f"Conversation ID age: {last_conversation_age}")
if last_conversation_age < configured_memory_time:
@@ -303,7 +301,7 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
if response_json["status"] != "completed":
_LOGGER.warning(f"Response status is not 'completed', got {response_json['status']}. Details: {response_json.get('incomplete_details', 'No details provided')}")
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> str | None:
def _extract_response(self, response_json: dict) -> str | None:
self._validate_response_payload(response_json)
self._check_response_status(response_json)
@@ -348,11 +346,12 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
entity_options: dict[str, Any]) -> TextGenerationResult:
"""Generate a response using the OpenAI-compatible Responses API (non-streaming endpoint wrapped as a single-chunk stream)."""
model_name = entity_options.get(CONF_CHAT_MODEL)
timeout = entity_options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
endpoint, additional_params = self._responses_params(conversation, api_base_path=entity_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
endpoint, additional_params = self._responses_params(conversation, entity_options)
request_params: Dict[str, Any] = {
"model": self.model_name,
"model": model_name,
}
request_params.update(additional_params)
@@ -374,7 +373,7 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
response_json = await response.json()
try:
text = self._extract_response(response_json, llm_api, user_input)
text = self._extract_response(response_json)
return TextGenerationResult(response=text, response_streamed=False)
except Exception as err:
_LOGGER.exception("Failed to parse Responses API payload: %s", err)
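
As a reading aid for the _responses_params changes above, an editor's sketch (not from the commit; all values are illustrative) of the request body the client ends up sending to the Responses endpoint when conversation memory applies:

request_params = {
    "model": "my-local-model",               # CONF_CHAT_MODEL taken from entity_options
    "input": "turn off the kitchen lights",  # last user message found in the conversation
    # attached only when the previous response is newer than the configured memory window
    "previous_response_id": "resp_abc123",
}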

View File

@@ -4,23 +4,22 @@ from __future__ import annotations
import logging
import os
from typing import Optional, Tuple, Dict, List, Any
from dataclasses import dataclass
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from custom_components.llama_conversation.const import (
CONF_MAX_TOKENS,
CONF_CHAT_MODEL,
CONF_TOP_K,
CONF_TYPICAL_P,
CONF_MIN_P,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_CONTEXT_LENGTH,
DEFAULT_MAX_TOKENS,
DEFAULT_TOP_K,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
@@ -37,12 +36,15 @@ from custom_components.llama_conversation.backends.generic_openai import Generic
_LOGGER = logging.getLogger(__name__)
class TextGenerationWebuiClient(GenericOpenAIAPIClient):
admin_key: str
admin_key: Optional[str]
async def _async_load_model(self, entry: ConfigEntry) -> None:
await super()._async_load_model(entry)
self.admin_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, self.api_key)
def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
super().__init__(hass, client_options)
self.admin_key = client_options.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)
async def _async_load_model(self, entity_options: dict[str, Any]) -> None:
model_name = entity_options.get(CONF_CHAT_MODEL)
try:
headers = {}
session = async_get_clientsession(self.hass)
@@ -58,18 +60,16 @@ class TextGenerationWebuiClient(GenericOpenAIAPIClient):
currently_loaded_result = await response.json()
loaded_model = currently_loaded_result["model_name"]
if loaded_model == self.model_name:
_LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.")
if loaded_model == model_name:
_LOGGER.info(f"Model {model_name} is already loaded on the remote backend.")
return
else:
_LOGGER.info(f"Model is not {self.model_name} loaded on the remote backend. Loading it now...")
_LOGGER.info(f"Model is not {model_name} loaded on the remote backend. Loading it now...")
async with session.post(
f"{self.api_host}/v1/internal/model/load",
json={
"model_name": self.model_name,
# TODO: expose arguments to the user in home assistant UI
# "args": {},
"model_name": model_name,
},
headers=headers
) as response:
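
For reference, an editor's sketch of the same model-load call made standalone against text-generation-webui's internal API (the endpoint and payload mirror the diff above; the helper function itself is illustrative):

import aiohttp

async def load_webui_model(api_host: str, model_name: str, admin_key: str | None = None) -> None:
    # text-generation-webui guards its internal endpoints with the admin key, when one is set
    headers = {"Authorization": f"Bearer {admin_key}"} if admin_key else {}
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{api_host}/v1/internal/model/load",
            json={"model_name": model_name},
            headers=headers,
        ) as response:
            response.raise_for_status()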
@@ -98,55 +98,14 @@ class TextGenerationWebuiClient(GenericOpenAIAPIClient):
return endpoint, request_params
async def _async_validate_text_generation_webui(self, user_input: dict) -> tuple:
"""
Validates a connection to text-generation-webui and that the model exists on the remote server
:param user_input: the input dictionary used to build the connection
:return: a tuple of (error message name, exception detail); both can be None
"""
try:
headers = {}
api_key = user_input.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, user_input.get(CONF_OPENAI_API_KEY))
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
session = async_get_clientsession(self.hass)
async with session.get(
format_url(
hostname=self.model_config[CONF_HOST],
port=self.model_config[CONF_PORT],
ssl=self.model_config[CONF_SSL],
path="/v1/internal/model/list"
),
timeout=5, # quick timeout
headers=headers
) as response:
response.raise_for_status()
models = await response.json()
for model in models["model_names"]:
if model == self.model_config[CONF_CHAT_MODEL].replace("/", "_"):
return None, None, []
return "missing_model_api", None, models["model_names"]
except Exception as ex:
_LOGGER.info("Connection error was: %s", repr(ex))
return "failed_to_connect", ex, []
class LlamaCppServerClient(GenericOpenAIAPIClient):
grammar: str
async def _async_load_model(self, entry: ConfigEntry):
await super()._async_load_model(entry)
def __init__(self, hass: HomeAssistant, client_options: Dict[str, Any]):
super().__init__(hass, client_options)
return await self.hass.async_add_executor_job(
self._load_model, entry
)
def _load_model(self, entry: ConfigEntry):
with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), DEFAULT_GBNF_GRAMMAR_FILE)) as f:
grammar_file_name = client_options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), grammar_file_name)) as f:
self.grammar = "".join(f.readlines())
def _chat_completion_params(self, entity_options: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
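
The grammar text loaded in __init__ above is presumably forwarded with each completion request; an editor's sketch (assumption, not shown in this commit) of how _chat_completion_params might attach it, given that llama.cpp's server accepts raw GBNF text as an extra "grammar" field:

request_params = {
    "model": "local-model",  # illustrative
    "messages": [{"role": "user", "content": "turn off the fan"}],
    "grammar": self.grammar,  # contents of the .gbnf file read in __init__ above
}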

View File

@@ -217,6 +217,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
"""Handle a config flow for Local LLM Conversation."""
VERSION = 3
MINOR_VERSION = 0
install_wheel_task = None
install_wheel_error = None
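
A closing note on how this VERSION bump ties back to the migration in __init__.py: Home Assistant compares a stored entry's version against the flow's VERSION at setup and calls async_migrate_entry when the stored one is lower. A simplified sketch of that relationship (editor's paraphrase, not the actual core code):

# simplified; the real check also compares MINOR_VERSION and handles failed migrations
if config_entry.version < ConfigFlow.VERSION:
    await async_migrate_entry(hass, config_entry)  # bumps entries from v2 to v3 as shown above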