mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
add missing downloaded file config attribute
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Final
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||
@@ -50,7 +51,7 @@ from .backends.llamacpp import LlamaCppClient
|
||||
from .backends.generic_openai import GenericOpenAIAPIClient, GenericOpenAIResponsesAPIClient
|
||||
from .backends.tailored_openai import TextGenerationWebuiClient, LlamaCppServerClient
|
||||
from .backends.ollama import OllamaAPIClient
|
||||
from .utils import get_llama_cpp_python_version
|
||||
from .utils import get_llama_cpp_python_version, download_model_from_hf
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -161,6 +162,27 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigE
|
||||
hass.config_entries.async_update_entry(config_entry, title=entry_title, data=entry_data, options=entry_options, version=3)
|
||||
|
||||
_LOGGER.debug("Migration to subentries complete")
|
||||
|
||||
if config_entry.version == 3 and config_entry.minor_version == 0:
|
||||
# add the downloaded model file to options if missing
|
||||
if config_entry.data.get(CONF_BACKEND_TYPE) == BACKEND_TYPE_LLAMA_CPP:
|
||||
for subentry in config_entry.subentries.values():
|
||||
if subentry.data.get(CONF_DOWNLOADED_MODEL_FILE) is None:
|
||||
model_name = subentry.data[CONF_CHAT_MODEL]
|
||||
quantization_type = subentry.data[CONF_DOWNLOADED_MODEL_QUANTIZATION]
|
||||
storage_folder = os.path.join(hass.config.media_dirs.get("local", hass.config.path("media")), "models")
|
||||
|
||||
new_options = dict(subentry.data)
|
||||
file_name = await hass.async_add_executor_job(download_model_from_hf, model_name, quantization_type, storage_folder, True)
|
||||
new_options[CONF_DOWNLOADED_MODEL_FILE] = file_name
|
||||
|
||||
hass.config_entries.async_update_subentry(
|
||||
config_entry, subentry, data=MappingProxyType(new_options)
|
||||
)
|
||||
|
||||
hass.config_entries.async_update_entry(config_entry, minor_version=1)
|
||||
|
||||
_LOGGER.debug("Migration to add downloaded model file complete")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -220,7 +220,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Local LLM Conversation."""
|
||||
|
||||
VERSION = 3
|
||||
MINOR_VERSION = 0
|
||||
MINOR_VERSION = 1
|
||||
|
||||
install_wheel_task = None
|
||||
install_wheel_error = None
|
||||
@@ -1071,10 +1071,10 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
if not self.download_task:
|
||||
model_name = self.model_config[CONF_CHAT_MODEL]
|
||||
quantization_type = self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION]
|
||||
|
||||
storage_folder = os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "models")
|
||||
|
||||
async def download_task():
|
||||
# return await self.hass.async_add_executor_job(
|
||||
await self.hass.async_add_executor_job(
|
||||
download_model_from_hf, model_name, quantization_type, storage_folder
|
||||
)
|
||||
|
||||
@@ -136,7 +136,7 @@ def custom_custom_serializer(value):
|
||||
|
||||
return cv.custom_serializer(value)
|
||||
|
||||
def download_model_from_hf(model_name: str, quantization_type: str, storage_folder: str):
|
||||
def download_model_from_hf(model_name: str, quantization_type: str, storage_folder: str, file_lookup_only: bool = False):
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download, HfFileSystem
|
||||
except Exception as ex:
|
||||
@@ -162,6 +162,7 @@ def download_model_from_hf(model_name: str, quantization_type: str, storage_fold
|
||||
repo_type="model",
|
||||
filename=wanted_file[0].removeprefix(model_name + "/"),
|
||||
cache_dir=storage_folder,
|
||||
local_files_only=file_lookup_only
|
||||
)
|
||||
|
||||
def _load_extension():
|
||||
|
||||
Reference in New Issue
Block a user