diff --git a/custom_components/llama_conversation/__init__.py b/custom_components/llama_conversation/__init__.py index ef4eaf6..d60e6f0 100644 --- a/custom_components/llama_conversation/__init__.py +++ b/custom_components/llama_conversation/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +import os from typing import Final from homeassistant.config_entries import ConfigEntry, ConfigSubentry @@ -50,7 +51,7 @@ from .backends.llamacpp import LlamaCppClient from .backends.generic_openai import GenericOpenAIAPIClient, GenericOpenAIResponsesAPIClient from .backends.tailored_openai import TextGenerationWebuiClient, LlamaCppServerClient from .backends.ollama import OllamaAPIClient -from .utils import get_llama_cpp_python_version +from .utils import get_llama_cpp_python_version, download_model_from_hf _LOGGER = logging.getLogger(__name__) @@ -161,6 +162,27 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigE hass.config_entries.async_update_entry(config_entry, title=entry_title, data=entry_data, options=entry_options, version=3) _LOGGER.debug("Migration to subentries complete") + + if config_entry.version == 3 and config_entry.minor_version == 0: + # add the downloaded model file to options if missing + if config_entry.data.get(CONF_BACKEND_TYPE) == BACKEND_TYPE_LLAMA_CPP: + for subentry in config_entry.subentries.values(): + if subentry.data.get(CONF_DOWNLOADED_MODEL_FILE) is None: + model_name = subentry.data[CONF_CHAT_MODEL] + quantization_type = subentry.data[CONF_DOWNLOADED_MODEL_QUANTIZATION] + storage_folder = os.path.join(hass.config.media_dirs.get("local", hass.config.path("media")), "models") + + new_options = dict(subentry.data) + file_name = await hass.async_add_executor_job(download_model_from_hf, model_name, quantization_type, storage_folder, True) + new_options[CONF_DOWNLOADED_MODEL_FILE] = file_name + + hass.config_entries.async_update_subentry( + config_entry, subentry, data=MappingProxyType(new_options) + ) + + hass.config_entries.async_update_entry(config_entry, minor_version=1) + + _LOGGER.debug("Migration to add downloaded model file complete") return True diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py index b2baf19..53862c6 100644 --- a/custom_components/llama_conversation/config_flow.py +++ b/custom_components/llama_conversation/config_flow.py @@ -220,7 +220,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN): """Handle a config flow for Local LLM Conversation.""" VERSION = 3 - MINOR_VERSION = 0 + MINOR_VERSION = 1 install_wheel_task = None install_wheel_error = None @@ -1071,10 +1071,10 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow): if not self.download_task: model_name = self.model_config[CONF_CHAT_MODEL] quantization_type = self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION] - storage_folder = os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "models") async def download_task(): + # return await self.hass.async_add_executor_job( await self.hass.async_add_executor_job( download_model_from_hf, model_name, quantization_type, storage_folder ) diff --git a/custom_components/llama_conversation/utils.py b/custom_components/llama_conversation/utils.py index e9d122c..3a3bee0 100644 --- a/custom_components/llama_conversation/utils.py +++ b/custom_components/llama_conversation/utils.py @@ -136,7 +136,7 @@ def custom_custom_serializer(value): return cv.custom_serializer(value) -def download_model_from_hf(model_name: str, quantization_type: str, storage_folder: str): +def download_model_from_hf(model_name: str, quantization_type: str, storage_folder: str, file_lookup_only: bool = False): try: from huggingface_hub import hf_hub_download, HfFileSystem except Exception as ex: @@ -162,6 +162,7 @@ def download_model_from_hf(model_name: str, quantization_type: str, storage_fold repo_type="model", filename=wanted_file[0].removeprefix(model_name + "/"), cache_dir=storage_folder, + local_files_only=file_lookup_only ) def _load_extension():