Mirror of https://github.com/acon96/home-llm.git, synced 2026-01-09 13:48:05 -05:00

Commit: work through todos
TODO.md
@@ -50,16 +50,15 @@
- [x] split out entity functionality so we can support conversation + ai tasks
- [x] fix icl examples to match new tool calling syntax config
- [x] set up docker-compose for running all of the various backends
- [x] fix the openai responses backend
- [ ] config sub-entry implementation
- [x] config sub-entry implementation
- [x] base work
- [x] generic openai backend
- [x] llamacpp backend
- [x] ollama backend
- [ ] tailored_openai backend
- [ ] generic openai responses backend
- [x] tailored_openai backend
- [x] generic openai responses backend
- [ ] fix and re-upload all compatible old models (+ upload all original safetensors)
- [ ] config entry migration function
- [x] config entry migration function

## more complicated ideas
- [ ] "context requests"

@@ -5,7 +5,7 @@ import logging
from typing import Final

from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import ATTR_ENTITY_ID, Platform
from homeassistant.const import ATTR_ENTITY_ID, Platform, CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, llm, device_registry as dr, entity_registry as er
from homeassistant.util.json import JsonObjectType
@@ -21,6 +21,19 @@ from .const import (
    SERVICE_TOOL_ALLOWED_SERVICES,
    SERVICE_TOOL_ALLOWED_DOMAINS,
    CONF_BACKEND_TYPE,
    CONF_SELECTED_LANGUAGE,
    CONF_OPENAI_API_KEY,
    CONF_GENERIC_OPENAI_PATH,
    CONF_CHAT_MODEL, CONF_DOWNLOADED_MODEL_QUANTIZATION, CONF_DOWNLOADED_MODEL_FILE, CONF_REQUEST_TIMEOUT, CONF_MAX_TOOL_CALL_ITERATIONS,
    CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS, CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
    CONF_PROMPT, CONF_TEMPERATURE, CONF_TOP_K, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, CONF_MAX_TOKENS,
    CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
    CONF_THINKING_PREFIX, CONF_THINKING_SUFFIX, CONF_TOOL_CALL_PREFIX, CONF_TOOL_CALL_SUFFIX,
    CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL, CONF_CONTEXT_LENGTH,
    CONF_LLAMACPP_BATCH_SIZE, CONF_LLAMACPP_THREAD_COUNT, CONF_LLAMACPP_BATCH_THREAD_COUNT,
    CONF_LLAMACPP_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE,
    CONF_TEXT_GEN_WEBUI_PRESET, CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_ENABLE_LEGACY_TOOL_CALLING,
    CONF_OLLAMA_JSON_MODE, CONF_OLLAMA_KEEP_ALIVE_MIN,
    DEFAULT_BACKEND_TYPE,
    BACKEND_TYPE_LLAMA_CPP,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
@@ -89,10 +102,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) ->
    hass.data[DOMAIN].pop(entry.entry_id)
    return True

# TODO: split out which options are per-model and which ones are conversation-specific
# and only migrate the conversation-specific ones to the subentry
ENTRY_KEYS = []
SUBENTRY_KEYS = []

async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigEntry):
    """Migrate old entry."""
@@ -102,88 +111,57 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigE
    if config_entry.version == 1:
        _LOGGER.error("Cannot upgrade models that were created prior to v0.3. Please delete and re-create them.")
        return False

    # If already at or above the target version nothing to do
    if config_entry.version >= 3:
        _LOGGER.debug("Entry already migrated (version %s)", config_entry.version)
        return True

    # Migrate each existing config entry to use subentries for conversations
    # We will create a conversation subentry using the entry.options plus any
    # model identifier stored in entry.data (CONF_CHAT_MODEL / CONF_DOWNLOADED_MODEL_FILE)
    # Migrate the config_entry to be an entry + sub-entry
    if config_entry.version == 2:
        ENTRY_DATA_KEYS = [CONF_BACKEND_TYPE, CONF_SELECTED_LANGUAGE]
        ENTRY_OPTIONS_KEYS = [CONF_HOST, CONF_PORT, CONF_SSL, CONF_OPENAI_API_KEY, CONF_GENERIC_OPENAI_PATH]
        SUBENTRY_KEYS = [
            CONF_CHAT_MODEL, CONF_DOWNLOADED_MODEL_QUANTIZATION, CONF_DOWNLOADED_MODEL_FILE, CONF_REQUEST_TIMEOUT, CONF_MAX_TOOL_CALL_ITERATIONS,
            CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS, CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
            CONF_LLM_HASS_API, CONF_PROMPT, CONF_TEMPERATURE, CONF_TOP_K, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, CONF_MAX_TOKENS,
            CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
            CONF_THINKING_PREFIX, CONF_THINKING_SUFFIX, CONF_TOOL_CALL_PREFIX, CONF_TOOL_CALL_SUFFIX,
            CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL, CONF_CONTEXT_LENGTH,
            CONF_LLAMACPP_BATCH_SIZE, CONF_LLAMACPP_THREAD_COUNT, CONF_LLAMACPP_BATCH_THREAD_COUNT,
            CONF_LLAMACPP_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE,
            CONF_TEXT_GEN_WEBUI_PRESET, CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_ENABLE_LEGACY_TOOL_CALLING,
            CONF_OLLAMA_JSON_MODE, CONF_OLLAMA_KEEP_ALIVE_MIN
        ]

        entries = sorted(
            hass.config_entries.async_entries(DOMAIN),
            key=lambda e: e.disabled_by is not None,
        )
        # Build entry data/options & subentry data from existing options and model info
        source_data = {**config_entry.data}
        source_data.update(config_entry.options)

        for entry in entries:
            # Skip entries that already have subentries
            if entry.subentries:
                continue

            # Build subentry data from existing options and model info
            subentry_data = { k: v for k, v in (entry.options or {}).items() if k in SUBENTRY_KEYS }
            entry_data = { k: v for k, v in (entry.data or {}).items() if k in ENTRY_KEYS }
        entry_data = { k: v for k, v in source_data.items() if k in ENTRY_DATA_KEYS }
        entry_options = { k: v for k, v in source_data.items() if k in ENTRY_OPTIONS_KEYS }
        subentry_data = { k: v for k, v in source_data.items() if k in SUBENTRY_KEYS }

        subentry = ConfigSubentry(
            data=MappingProxyType(subentry_data),
            subentry_type="conversation",
            title=entry.title,
            title=config_entry.title, # FIXME: should be the "new" name format
            unique_id=None,
        )

        hass.config_entries.async_add_subentry(entry, subentry)
        # don't attempt to create duplicate llama.cpp backends
        if entry_data[CONF_BACKEND_TYPE] == BACKEND_TYPE_LLAMA_CPP:
            all_entries = hass.config_entries.async_entries(DOMAIN)
            for entry in all_entries:
                if entry.version < 3 or entry.data[CONF_BACKEND_TYPE] != BACKEND_TYPE_LLAMA_CPP:
                    continue

                # Move entity/device registry associations to the new subentry where applicable
                entity_registry = er.async_get(hass)
                device_registry = dr.async_get(hass)
                await hass.config_entries.async_remove(config_entry.entry_id)
                config_entry = entry
                break

        conversation_entity_id = entity_registry.async_get_entity_id(
            "conversation",
            DOMAIN,
            entry.entry_id,
        )
        device = device_registry.async_get_device(identifiers={(DOMAIN, entry.entry_id)})
        # create sub-entry
        hass.config_entries.async_add_subentry(config_entry, subentry)

        if conversation_entity_id is not None:
            conversation_entity_entry = entity_registry.entities[conversation_entity_id]
            entity_disabled_by = conversation_entity_entry.disabled_by
            # Keep a sensible disabled flag when migrating
            if (
                entity_disabled_by is er.RegistryEntryDisabler.CONFIG_ENTRY
                and not all(e.disabled_by is not None for e in entries if e.entry_id != entry.entry_id)
            ):
                entity_disabled_by = (
                    er.RegistryEntryDisabler.DEVICE if device else er.RegistryEntryDisabler.USER
                )
            entity_registry.async_update_entity(
                conversation_entity_id,
                config_entry_id=entry.entry_id,
                config_subentry_id=subentry.subentry_id,
                disabled_by=entity_disabled_by,
                new_unique_id=subentry.subentry_id,
            )
        # update the parent entry
        hass.config_entries.async_update_entry(config_entry, data=entry_data, options=entry_options, version=3)

        if device is not None:
            # Adjust device registry identifiers to point to the subentry
            device_disabled_by = device.disabled_by
            if (
                device.disabled_by is dr.DeviceEntryDisabler.CONFIG_ENTRY
            ):
                device_disabled_by = dr.DeviceEntryDisabler.USER
            device_registry.async_update_device(
                device.id,
                disabled_by=device_disabled_by,
                new_identifiers={(DOMAIN, subentry.subentry_id)},
                add_config_subentry_id=subentry.subentry_id,
                add_config_entry_id=entry.entry_id,
            )

        # Update the parent entry to remove model-level fields and clear options
        hass.config_entries.async_update_entry(entry, data=entry_data, options={}, version=3)

    _LOGGER.debug("Migration to subentries complete")
    _LOGGER.debug("Migration to subentries complete")

    return True

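The hunk above replaces the earlier per-entry migration loop with a direct v2 to v3 migration of the entry being upgraded: backend identity stays in `entry.data`, connection settings move to `entry.options`, and everything model- and conversation-specific moves into a `conversation` sub-entry. A minimal sketch of that split, using simplified, hypothetical key lists and plain dicts in place of the Home Assistant config entry object:

```python
from types import MappingProxyType

# Hypothetical, simplified key lists; the real lists come from the integration's const.py.
ENTRY_DATA_KEYS = ["backend_type", "selected_language"]
ENTRY_OPTIONS_KEYS = ["host", "port", "ssl", "openai_api_key"]
SUBENTRY_KEYS = ["chat_model", "prompt", "temperature", "max_tokens"]

def split_v2_entry(data: dict, options: dict) -> tuple[dict, dict, MappingProxyType]:
    """Split a flat v2 entry (data + options) into v3 entry data,
    entry options, and the data for a 'conversation' sub-entry."""
    source = {**data, **options}
    entry_data = {k: v for k, v in source.items() if k in ENTRY_DATA_KEYS}
    entry_options = {k: v for k, v in source.items() if k in ENTRY_OPTIONS_KEYS}
    subentry_data = MappingProxyType(
        {k: v for k, v in source.items() if k in SUBENTRY_KEYS}
    )
    return entry_data, entry_options, subentry_data

# Example: a v2 entry where everything lived in one flat dict.
entry_data, entry_options, subentry_data = split_v2_entry(
    {"backend_type": "ollama", "selected_language": "en"},
    {"host": "localhost", "port": 11434, "ssl": False, "chat_model": "llama3"},
)
print(entry_data)           # {'backend_type': 'ollama', 'selected_language': 'en'}
print(entry_options)        # {'host': 'localhost', 'port': 11434, 'ssl': False}
print(dict(subentry_data))  # {'chat_model': 'llama3'}
```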
@@ -45,7 +45,6 @@ class GenericOpenAIAPIClient(LocalLLMClient):
    """Implements the OpenAPI-compatible text completion and chat completion API backends."""

    api_host: str
    api_base_path: str
    api_key: str

    _attr_supports_streaming = True
@@ -78,7 +77,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
                    ssl=user_input[CONF_SSL],
                    path=f"/{api_base_path}/models"
                ),
                timeout=5, # quick timeout
                timeout=aiohttp.ClientTimeout(total=5), # quick timeout
                headers=headers
            ) as response:
                if response.ok:
@@ -96,7 +95,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
        session = async_get_clientsession(self.hass)
        async with session.get(
            f"{self.api_host}/models",
            timeout=5, # quick timeout
            timeout=aiohttp.ClientTimeout(total=5), # quick timeout
            headers=headers
        ) as response:
            response.raise_for_status()
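Both timeout hunks make the same fix: the bare numeric `timeout=5` is swapped for an `aiohttp.ClientTimeout` object, which is the form newer aiohttp releases expect for the per-request `timeout` argument. A minimal, self-contained sketch of the corrected call shape (the URL and key here are placeholders, not values from the integration):

```python
import aiohttp
import asyncio

async def fetch_models(base_url: str, api_key: str | None = None) -> dict:
    """Fetch the /models listing with a short, explicit timeout."""
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    timeout = aiohttp.ClientTimeout(total=5)  # replaces the old `timeout=5`
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base_url}/models", timeout=timeout, headers=headers) as response:
            response.raise_for_status()
            return await response.json()

# asyncio.run(fetch_models("http://localhost:8080/v1"))
```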
@@ -212,7 +211,6 @@ class GenericOpenAIAPIClient(LocalLLMClient):
        return response_text, tool_calls


# FIXME: this class is mostly broken
class GenericOpenAIResponsesAPIClient(LocalLLMClient):
    """Implements the OpenAPI-compatible Responses API backend."""

@@ -224,21 +222,21 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
    _last_response_id: str | None = None
    _last_response_id_time: datetime.datetime | None = None

    async def _async_load_model(self, entry: ConfigEntry) -> None:
    def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
        super().__init__(hass, client_options)
        self.api_host = format_url(
            hostname=entry.data[CONF_HOST],
            port=entry.data[CONF_PORT],
            ssl=entry.data[CONF_SSL],
            path=""
            hostname=client_options[CONF_HOST],
            port=client_options[CONF_PORT],
            ssl=client_options[CONF_SSL],
            path="/" + client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
        )

        self.api_key = entry.data.get(CONF_OPENAI_API_KEY, "")
        self.api_key = client_options.get(CONF_OPENAI_API_KEY, "")

    def _responses_params(self, conversation: List[conversation.Content], api_base_path: str) -> Tuple[str, Dict[str, Any]]:
    def _responses_params(self, conversation: List[conversation.Content], entity_options: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
        request_params = {}
        api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

        endpoint = f"/{api_base_path}/responses"
        endpoint = "/responses"
        # Find the last user message in the conversation and use its content as the input
        input_text: str | None = None
        for msg in reversed(conversation):
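With this change the Responses client no longer reads the config entry directly: it is constructed from a `client_options` mapping and derives its base URL (including the configurable API path) up front, so `_responses_params` only has to append `/responses`. A rough sketch of that URL assembly, assuming `format_url` behaves like a simple scheme/host/port/path formatter (the stand-in below and the option keys are illustrative, not the integration's real helper):

```python
def format_url(hostname: str, port: int | str, ssl: bool, path: str) -> str:
    """Hypothetical stand-in for the integration's format_url helper."""
    scheme = "https" if ssl else "http"
    return f"{scheme}://{hostname}:{port}{path}"

# client_options roughly as the new __init__ would receive them (illustrative values).
client_options = {"host": "localhost", "port": 8080, "ssl": False, "generic_openai_path": "v1"}

api_host = format_url(
    hostname=client_options["host"],
    port=client_options["port"],
    ssl=client_options["ssl"],
    path="/" + client_options.get("generic_openai_path", "v1"),
)
print(api_host)  # http://localhost:8080/v1
```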
@@ -256,9 +254,9 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
        request_params["input"] = input_text

        # Assign previous_response_id if relevant
        if self._last_response_id and self._last_response_id_time and self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION):
        if self._last_response_id and self._last_response_id_time and entity_options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION):
            # If the last response was generated recently, use it as a context
            configured_memory_time: datetime.timedelta = datetime.timedelta(minutes=self.entry.options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES))
            configured_memory_time: datetime.timedelta = datetime.timedelta(minutes=entity_options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES))
            last_conversation_age: datetime.timedelta = datetime.datetime.now() - self._last_response_id_time
            _LOGGER.debug(f"Conversation ID age: {last_conversation_age}")
            if last_conversation_age < configured_memory_time:
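The reuse of `previous_response_id` is gated on how recently the last response was produced: if the stored timestamp is older than the configured memory window, the client starts a fresh conversation instead of continuing the old one. A small, self-contained sketch of that age check:

```python
import datetime

def should_reuse_response_id(
    last_response_time: datetime.datetime | None,
    remember_minutes: int,
) -> bool:
    """Return True if the previous response is recent enough to continue from."""
    if last_response_time is None:
        return False
    memory_window = datetime.timedelta(minutes=remember_minutes)
    age = datetime.datetime.now() - last_response_time
    return age < memory_window

# Example: a response generated 3 minutes ago with a 5 minute window is reused.
three_minutes_ago = datetime.datetime.now() - datetime.timedelta(minutes=3)
print(should_reuse_response_id(three_minutes_ago, remember_minutes=5))  # True
```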
@@ -303,7 +301,7 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
        if response_json["status"] != "completed":
            _LOGGER.warning(f"Response status is not 'completed', got {response_json['status']}. Details: {response_json.get('incomplete_details', 'No details provided')}")

    def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> str | None:
    def _extract_response(self, response_json: dict) -> str | None:
        self._validate_response_payload(response_json)
        self._check_response_status(response_json)

@@ -348,11 +346,12 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
                             entity_options: dict[str, Any]) -> TextGenerationResult:
        """Generate a response using the OpenAI-compatible Responses API (non-streaming endpoint wrapped as a single-chunk stream)."""

        model_name = entity_options.get(CONF_CHAT_MODEL)
        timeout = entity_options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
        endpoint, additional_params = self._responses_params(conversation, api_base_path=entity_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
        endpoint, additional_params = self._responses_params(conversation, entity_options)

        request_params: Dict[str, Any] = {
            "model": self.model_name,
            "model": model_name,
        }
        request_params.update(additional_params)

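Putting the pieces together, `_responses_params` now contributes the endpoint plus the `input` and `previous_response_id` fields, while the caller adds the model name from the per-entity options. The assembled request body ends up roughly like the sketch below; the field values are illustrative, and `previous_response_id` is only present when the memory-window check above passes:

```python
request_params = {
    "model": "gpt-4o-mini",                  # from entity_options[CONF_CHAT_MODEL]
    "input": "turn off the kitchen lights",  # last user message in the conversation
    "previous_response_id": "resp_abc123",   # only when a recent response id is remembered
}
# POSTed by the client to f"{api_host}/responses".
```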
@@ -374,7 +373,7 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
            response_json = await response.json()

        try:
            text = self._extract_response(response_json, llm_api, user_input)
            text = self._extract_response(response_json)
            return TextGenerationResult(response=text, response_streamed=False)
        except Exception as err:
            _LOGGER.exception("Failed to parse Responses API payload: %s", err)

@@ -4,23 +4,22 @@ from __future__ import annotations
import logging
import os
from typing import Optional, Tuple, Dict, List, Any
from dataclasses import dataclass

from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from custom_components.llama_conversation.const import (
    CONF_MAX_TOKENS,
    CONF_CHAT_MODEL,
    CONF_TOP_K,
    CONF_TYPICAL_P,
    CONF_MIN_P,
    CONF_USE_GBNF_GRAMMAR,
    CONF_GBNF_GRAMMAR_FILE,
    CONF_TEXT_GEN_WEBUI_PRESET,
    CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
    CONF_TEXT_GEN_WEBUI_CHAT_MODE,
    CONF_CONTEXT_LENGTH,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TOP_K,
    DEFAULT_MIN_P,
    DEFAULT_TYPICAL_P,
@@ -37,12 +36,15 @@ from custom_components.llama_conversation.backends.generic_openai import Generic
_LOGGER = logging.getLogger(__name__)

class TextGenerationWebuiClient(GenericOpenAIAPIClient):
    admin_key: str
    admin_key: Optional[str]

    async def _async_load_model(self, entry: ConfigEntry) -> None:
        await super()._async_load_model(entry)
        self.admin_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, self.api_key)
    def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
        super().__init__(hass, client_options)

        self.admin_key = client_options.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)

    async def _async_load_model(self, entity_options: dict[str, Any]) -> None:
        model_name = entity_options.get(CONF_CHAT_MODEL)
        try:
            headers = {}
            session = async_get_clientsession(self.hass)
@@ -58,18 +60,16 @@ class TextGenerationWebuiClient(GenericOpenAIAPIClient):
                currently_loaded_result = await response.json()

                loaded_model = currently_loaded_result["model_name"]
                if loaded_model == self.model_name:
                    _LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.")
                if loaded_model == model_name:
                    _LOGGER.info(f"Model {model_name} is already loaded on the remote backend.")
                    return
                else:
                    _LOGGER.info(f"Model is not {self.model_name} loaded on the remote backend. Loading it now...")
                    _LOGGER.info(f"Model is not {model_name} loaded on the remote backend. Loading it now...")

            async with session.post(
                f"{self.api_host}/v1/internal/model/load",
                json={
                    "model_name": self.model_name,
                    # TODO: expose arguments to the user in home assistant UI
                    # "args": {},
                    "model_name": model_name,
                },
                headers=headers
            ) as response:
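For reference, the load request the client issues when a different model is active is a plain POST to text-generation-webui's internal API, as shown in the hunk above. A minimal standalone sketch of that call (host, model name, and key are placeholders; the admin key goes into the Authorization header when one is configured):

```python
import aiohttp
import asyncio

async def load_model(api_host: str, model_name: str, admin_key: str | None = None) -> None:
    """Ask text-generation-webui to switch to the given model."""
    headers = {"Authorization": f"Bearer {admin_key}"} if admin_key else {}
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{api_host}/v1/internal/model/load",
            json={"model_name": model_name},
            headers=headers,
        ) as response:
            response.raise_for_status()

# asyncio.run(load_model("http://localhost:5000", "my-model"))
```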
@@ -98,55 +98,14 @@ class TextGenerationWebuiClient(GenericOpenAIAPIClient):
        return endpoint, request_params


    async def _async_validate_text_generation_webui(self, user_input: dict) -> tuple:
        """
        Validates a connection to text-generation-webui and that the model exists on the remote server

        :param user_input: the input dictionary used to build the connection
        :return: a tuple of (error message name, exception detail); both can be None
        """
        try:
            headers = {}
            api_key = user_input.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, user_input.get(CONF_OPENAI_API_KEY))
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            session = async_get_clientsession(self.hass)
            async with session.get(
                format_url(
                    hostname=self.model_config[CONF_HOST],
                    port=self.model_config[CONF_PORT],
                    ssl=self.model_config[CONF_SSL],
                    path="/v1/internal/model/list"
                ),
                timeout=5, # quick timeout
                headers=headers
            ) as response:
                response.raise_for_status()
                models = await response.json()

                for model in models["model_names"]:
                    if model == self.model_config[CONF_CHAT_MODEL].replace("/", "_"):
                        return None, None, []

                return "missing_model_api", None, models["model_names"]

        except Exception as ex:
            _LOGGER.info("Connection error was: %s", repr(ex))
            return "failed_to_connect", ex, []

class LlamaCppServerClient(GenericOpenAIAPIClient):
    grammar: str

    async def _async_load_model(self, entry: ConfigEntry):
        await super()._async_load_model(entry)
    def __init__(self, hass: HomeAssistant, client_options: Dict[str, Any]):
        super().__init__(hass, client_options)

        return await self.hass.async_add_executor_job(
            self._load_model, entry
        )

    def _load_model(self, entry: ConfigEntry):
        with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), DEFAULT_GBNF_GRAMMAR_FILE)) as f:
        grammar_file_name = client_options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
        with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), grammar_file_name)) as f:
            self.grammar = "".join(f.readlines())

    def _chat_completion_params(self, entity_options: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:

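The llama.cpp server client now resolves the grammar file name from `client_options` instead of always using the default, and reads it from the directory one level above the `backends/` package (per the double `os.path.dirname` in the hunk above). A small sketch of that path resolution; the default file name below is assumed for illustration, not taken from the integration:

```python
import os

DEFAULT_GBNF_GRAMMAR_FILE = "output.gbnf"  # assumed default name, for illustration only

def read_grammar(backend_module_file: str, grammar_file_name: str | None = None) -> str:
    """Read a GBNF grammar that sits next to the integration, one level above the backends package."""
    name = grammar_file_name or DEFAULT_GBNF_GRAMMAR_FILE
    integration_dir = os.path.dirname(os.path.dirname(backend_module_file))
    with open(os.path.join(integration_dir, name)) as f:
        return f.read()

# grammar = read_grammar(__file__)
```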
@@ -217,6 +217,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
    """Handle a config flow for Local LLM Conversation."""

    VERSION = 3
    MINOR_VERSION = 0

    install_wheel_task = None
    install_wheel_error = None

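Bumping `VERSION` to 3 here is what triggers the migration shown earlier: Home Assistant compares a stored entry's version with the config flow's `VERSION`/`MINOR_VERSION` and calls the integration's `async_migrate_entry` when the stored entry is older. A compressed sketch of that contract, with the integration-specific work elided:

```python
# Sketch of the config-entry versioning contract, not the integration's real code.
class ConfigFlow:                 # in Home Assistant: config_entries.ConfigFlow
    VERSION = 3                   # stamped on every newly created entry
    MINOR_VERSION = 0

async def async_migrate_entry(hass, config_entry) -> bool:
    """Called by Home Assistant when config_entry.version < ConfigFlow.VERSION."""
    if config_entry.version == 1:
        return False              # too old to migrate automatically
    if config_entry.version == 2:
        ...                       # split data/options and create the conversation sub-entry
    return True                   # entry is now at version 3
```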