There is a subtle change in behaviour with the new model probe API. Previously, checks for model types were done in a specific order: for example, all main model checks ran before LoRA checks. With the new API, the ordering has changed:

- New API checks run first, then legacy API checks.
- New API checks are categorized by their speed. When we run new API checks, we sort them from fastest to slowest and run them in that order. This is a performance optimization.

Currently, LoRA and LLaVA models are the only model types on the new API, so their checks run first. The LoRA check inspects the state dict for keys with specific prefixes, which we expect to exist only in LoRAs. It turns out that main models may have some of these keys too. For example, this model has keys that match the LoRA prefix `lora_te_`: https://civitai.com/models/134442/helloyoung25d

Under the old probe, we'd do the main model checks first and correctly identify this as a main model. But with the new setup, we do the LoRA check first, and it passes. So we import this model as a LoRA.

Thankfully, the old probe still exists. For now, the new probe is fully disabled; it was only called in one spot. I've also added the affected model as a test case for the model probe. Right now, this causes the test to fail, and I've marked the test as xfail, so CI will pass. Once we enable the new API again, the xfail will pass, CI will fail, and we'll be reminded to update the test.
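To make the failure mode concrete, here is a minimal sketch of the ambiguity (the helper and key names are illustrative, not the actual probe logic):

    # A cheap prefix test like this is why the new API runs the LoRA check
    # early - but a main-model checkpoint can contain matching keys too.
    def looks_like_lora(state_dict: dict) -> bool:
        return any(key.startswith("lora_te_") for key in state_dict)

    main_model_sd = {
        "model.diffusion_model.input_blocks.0.0.weight": ...,
        "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight": ...,
    }
    assert looks_like_lora(main_model_sd)  # false positive: this is a main model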
949 lines
41 KiB
Python
"""Model installation class."""
|
|
|
|
import locale
|
|
import os
|
|
import re
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from queue import Empty, Queue
|
|
from shutil import copyfile, copytree, move, rmtree
|
|
from tempfile import mkdtemp
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
|
|
|
|
import torch
|
|
import yaml
|
|
from huggingface_hub import HfFolder
|
|
from pydantic.networks import AnyHttpUrl
|
|
from pydantic_core import Url
|
|
from requests import Session
|
|
|
|
from invokeai.app.services.config import InvokeAIAppConfig
|
|
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
|
|
from invokeai.app.services.invoker import Invoker
|
|
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
|
|
from invokeai.app.services.model_install.model_install_common import (
|
|
MODEL_SOURCE_TO_TYPE_MAP,
|
|
HFModelSource,
|
|
InstallStatus,
|
|
LocalModelSource,
|
|
ModelInstallJob,
|
|
ModelSource,
|
|
StringLikeSource,
|
|
URLModelSource,
|
|
)
|
|
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
|
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
|
from invokeai.backend.model_manager.config import (
|
|
AnyModelConfig,
|
|
CheckpointConfigBase,
|
|
InvalidModelConfigException,
|
|
)
|
|
from invokeai.backend.model_manager.legacy_probe import ModelProbe
|
|
from invokeai.backend.model_manager.metadata import (
|
|
AnyModelRepoMetadata,
|
|
HuggingFaceMetadataFetch,
|
|
ModelMetadataFetchBase,
|
|
ModelMetadataWithFiles,
|
|
RemoteModelFile,
|
|
)
|
|
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
|
|
from invokeai.backend.model_manager.search import ModelSearch
|
|
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
|
|
from invokeai.backend.util import InvokeAILogger
|
|
from invokeai.backend.util.catch_sigint import catch_sigint
|
|
from invokeai.backend.util.devices import TorchDevice
|
|
from invokeai.backend.util.util import slugify
|
|
|
|
if TYPE_CHECKING:
|
|
from invokeai.app.services.events.events_base import EventServiceBase
|
|
|
|
|
|
TMPDIR_PREFIX = "tmpinstall_"
|
|
|
|
|
|
class ModelInstallService(ModelInstallServiceBase):
    """Model installation service for InvokeAI."""

    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        record_store: ModelRecordServiceBase,
        download_queue: DownloadQueueServiceBase,
        event_bus: Optional["EventServiceBase"] = None,
        session: Optional[Session] = None,
    ):
        """
        Initialize the installer object.

        :param app_config: InvokeAIAppConfig object
        :param record_store: Previously-opened ModelRecordService database
        :param download_queue: DownloadQueueService used to fetch remote models
        :param event_bus: Optional EventService object
        :param session: Optional requests Session, used primarily for testing
        """
        self._app_config = app_config
        self._record_store = record_store
        self._event_bus = event_bus
        self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
        self._install_jobs: List[ModelInstallJob] = []
        self._install_queue: Queue[ModelInstallJob] = Queue()
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self._downloads_changed_event = threading.Event()
        self._install_completed_event = threading.Event()
        self._download_queue = download_queue
        self._download_cache: Dict[int, ModelInstallJob] = {}
        self._running = False
        self._session = session
        self._install_thread: Optional[threading.Thread] = None
        self._next_job_id = 0

    @property
    def app_config(self) -> InvokeAIAppConfig:  # noqa D102
        return self._app_config

    @property
    def record_store(self) -> ModelRecordServiceBase:  # noqa D102
        return self._record_store

    @property
    def event_bus(self) -> Optional["EventServiceBase"]:  # noqa D102
        return self._event_bus

    # The invoker is optional because it isn't needed here, and requiring it
    # would make the installer harder to use outside the web app.
    def start(self, invoker: Optional[Invoker] = None) -> None:
        """Start the installer thread."""
        with self._lock:
            if self._running:
                raise Exception("Attempt to start the installer service twice")
            self._start_installer_thread()
            self._remove_dangling_install_dirs()
            self._migrate_yaml()
            # In normal use, we do not want to scan the models directory - it should never have orphaned models.
            # We should only do the scan when the flag is set (which should only be set when testing).
            if self.app_config.scan_models_on_startup:
                with catch_sigint():
                    self._register_orphaned_models()

            # Check all models' paths and confirm they exist. A model could be missing if it was installed on a volume
            # that isn't currently mounted. In this case, we don't want to delete the model from the database, but we
            # do want to alert the user.
            for model in self._scan_for_missing_models():
                self._logger.warning(f"Missing model file: {model.name} at {model.path}")

    def stop(self, invoker: Optional[Invoker] = None) -> None:
        """Stop the installer thread; after this the object can be deleted and garbage collected."""
        if not self._running:
            raise Exception("Attempt to stop the install service before it was started")
        self._logger.debug("calling stop_event.set()")
        self._stop_event.set()
        self._clear_pending_jobs()
        self._download_cache.clear()
        assert self._install_thread is not None
        self._install_thread.join()
        self._running = False

    def _clear_pending_jobs(self) -> None:
        for job in self.list_jobs():
            if not job.in_terminal_state:
                self._logger.warning(f"Cancelling job {job.id}")
                self.cancel_job(job)
        # Drain anything still sitting in the queue.
        while True:
            try:
                job = self._install_queue.get(block=False)
                self._install_queue.task_done()
            except Empty:
                break

    def _put_in_queue(self, job: ModelInstallJob) -> None:
        if self._stop_event.is_set():
            self.cancel_job(job)
        else:
            self._install_queue.put(job)

    def register_path(
        self,
        model_path: Union[Path, str],
        config: Optional[ModelRecordChanges] = None,
    ) -> str:  # noqa D102
        model_path = Path(model_path)
        config = config or ModelRecordChanges()
        if not config.source:
            config.source = model_path.resolve().as_posix()
            config.source_type = ModelSourceType.Path
        return self._register(model_path, config)

    def install_path(
        self,
        model_path: Union[Path, str],
        config: Optional[ModelRecordChanges] = None,
    ) -> str:  # noqa D102
        model_path = Path(model_path)
        config = config or ModelRecordChanges()
        info: AnyModelConfig = self._probe(Path(model_path), config)  # type: ignore

        if preferred_name := config.name:
            preferred_name = Path(preferred_name).with_suffix(model_path.suffix)

        dest_path = (
            self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
        )
        try:
            new_path = self._copy_model(model_path, dest_path)
        except FileExistsError as excp:
            raise DuplicateModelException(
                f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
            ) from excp

        return self._register(
            new_path,
            config,
            info,
        )

    def heuristic_import(
        self,
        source: str,
        config: Optional[ModelRecordChanges] = None,
        access_token: Optional[str] = None,
        inplace: Optional[bool] = False,
    ) -> ModelInstallJob:
        """Install a model using pattern matching to infer the type of source."""
        source_obj = self._guess_source(source)
        if isinstance(source_obj, LocalModelSource):
            source_obj.inplace = inplace
        elif isinstance(source_obj, (HFModelSource, URLModelSource)):
            source_obj.access_token = access_token
        return self.import_model(source_obj, config)

    def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob:  # noqa D102
        similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
        if similar_jobs:
            self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
            return similar_jobs[0]

        if isinstance(source, LocalModelSource):
            install_job = self._import_local_model(source, config)
            self._put_in_queue(install_job)  # synchronously install
        elif isinstance(source, HFModelSource):
            install_job = self._import_from_hf(source, config)
        elif isinstance(source, URLModelSource):
            install_job = self._import_from_url(source, config)
        else:
            raise ValueError(f"Unsupported model source: '{type(source)}'")

        self._install_jobs.append(install_job)
        return install_job

    def list_jobs(self) -> List[ModelInstallJob]:  # noqa D102
        return self._install_jobs

    def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:  # noqa D102
        return [x for x in self._install_jobs if x.source == source]

    def get_job_by_id(self, id: int) -> ModelInstallJob:  # noqa D102
        jobs = [x for x in self._install_jobs if x.id == id]
        if not jobs:
            raise ValueError(f"No job with id {id} known")
        assert len(jobs) == 1
        assert isinstance(jobs[0], ModelInstallJob)
        return jobs[0]

    def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
        """Block until the indicated job reaches a terminal state, or until the timeout is reached."""
        start = time.time()
        while not job.in_terminal_state:
            if self._install_completed_event.wait(timeout=5):  # in case we miss an event
                self._install_completed_event.clear()
            if timeout > 0 and time.time() - start > timeout:
                raise TimeoutError("Timeout exceeded")
        return job

    def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:  # noqa D102
        """Block until all installation jobs are done."""
        start = time.time()
        while len(self._download_cache) > 0:
            if self._downloads_changed_event.wait(timeout=0.25):  # in case we miss an event
                self._downloads_changed_event.clear()
            if timeout > 0 and time.time() - start > timeout:
                raise TimeoutError("Timeout exceeded")
        self._install_queue.join()

        return self._install_jobs

    def cancel_job(self, job: ModelInstallJob) -> None:
        """Cancel the indicated job."""
        job.cancel()
        self._logger.warning(f"Cancelling {job.source}")
        if dj := job._multifile_job:
            self._download_queue.cancel_job(dj)

    def prune_jobs(self) -> None:
        """Prune all completed and errored jobs."""
        unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
        self._install_jobs = unfinished_jobs

    def _migrate_yaml(self) -> None:
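        # A v3 `models.yaml` looks roughly like this (illustrative stanza; keys
        # are "base/type/name" triples, as the split("/") below assumes):
        #
        #   __metadata__:
        #     version: 3.0.0
        #   sd-1/main/stable-diffusion-v1-5:
        #     path: sd-1/main/stable-diffusion-v1-5
        #     description: Stable Diffusion v1.5
        #     config: configs/stable-diffusion/v1-inference.yaml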
        db_models = self.record_store.all_models()

        legacy_models_yaml_path = (
            self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
        )

        # The old path may be relative to the root path
        if not legacy_models_yaml_path.exists():
            legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)

        if legacy_models_yaml_path.exists():
            with open(legacy_models_yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
                legacy_models_yaml = yaml.safe_load(file)

            yaml_metadata = legacy_models_yaml.pop("__metadata__")
            yaml_version = yaml_metadata.get("version")

            if yaml_version != "3.0.0":
                raise ValueError(
                    f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
                )

            self._logger.info(
                f"Starting one-time migration of {len(legacy_models_yaml.items())} models from {str(legacy_models_yaml_path)}. This may take a few minutes."
            )

            if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0:
                for model_key, stanza in legacy_models_yaml.items():
                    _, _, model_name = str(model_key).split("/")
                    model_path = Path(stanza["path"])
                    if not model_path.is_absolute():
                        model_path = self._app_config.models_path / model_path
                    model_path = model_path.resolve()

                    config = ModelRecordChanges(
                        name=model_name,
                        description=stanza.get("description"),
                    )
                    legacy_config_path = stanza.get("config")
                    if legacy_config_path:
                        # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
                        legacy_config_path = self._app_config.root_path / legacy_config_path
                        if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
                            legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
                        config.config_path = str(legacy_config_path)
                    try:
                        id = self.register_path(model_path=model_path, config=config)
                        self._logger.info(f"Migrated {model_name} with id {id}")
                    except Exception as e:
                        self._logger.warning(f"Model at {model_path} could not be migrated: {e}")

            # Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
            legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))

        # Unset the path - we are done with it either way
        self._app_config.legacy_models_yaml_path = None

    def unregister(self, key: str) -> None:  # noqa D102
        self.record_store.del_model(key)

    def delete(self, key: str) -> None:  # noqa D102
        """Unregister the model. Delete its files only if they are within our models directory."""
        model = self.record_store.get_model(key)
        model_path = self.app_config.models_path / model.path

        if model_path.is_relative_to(self.app_config.models_path):
            # If the model is in the Invoke-managed models dir, we delete its files.
            self.unconditionally_delete(key)
        else:
            # Otherwise we only unregister it, leaving the files in place.
            self.unregister(key)

    def unconditionally_delete(self, key: str) -> None:  # noqa D102
        model = self.record_store.get_model(key)
        model_path = self.app_config.models_path / model.path
        if model_path.is_file() or model_path.is_symlink():
            model_path.unlink()
        elif model_path.is_dir():
            rmtree(model_path)
        self.unregister(key)

    @classmethod
    def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
        # The source (e.g. a URL) is slugified into a single filesystem-safe directory name.
        escaped_source = slugify(str(source))
        return app_config.download_cache_path / escaped_source

    def download_and_cache_model(
        self,
        source: str | AnyHttpUrl,
    ) -> Path:
        """Download the model file located at source to the models cache and return its Path."""
        model_path = self._download_cache_path(str(source), self._app_config)

        # We expect the cache directory to contain one and only one downloaded file or directory.
        # We don't know the file's name in advance, as it is set by the remote server's
        # content-disposition header.
        if model_path.exists():
            contents: List[Path] = list(model_path.iterdir())
            if len(contents) > 0:
                return contents[0]

        model_path.mkdir(parents=True, exist_ok=True)
        model_source = self._guess_source(str(source))
        remote_files, _ = self._remote_files_from_source(model_source)
        job = self._multifile_download(
            dest=model_path,
            remote_files=remote_files,
            subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
        )
        files_string = "file" if len(remote_files) == 1 else "files"
        self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
        self._download_queue.wait_for_job(job)
        if job.complete:
            assert job.download_path is not None
            return job.download_path
        else:
            raise Exception(job.error)

    def _remote_files_from_source(
        self, source: ModelSource
    ) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
        metadata = None
        if isinstance(source, HFModelSource):
            metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
            assert isinstance(metadata, ModelMetadataWithFiles)
            return (
                metadata.download_urls(
                    variant=source.variant or self._guess_variant(),
                    subfolder=source.subfolder,
                    session=self._session,
                ),
                metadata,
            )

        if isinstance(source, URLModelSource):
            try:
                fetcher = self.get_fetcher_from_url(str(source.url))
                kwargs: dict[str, Any] = {"session": self._session}
                metadata = fetcher(**kwargs).from_url(source.url)
                assert isinstance(metadata, ModelMetadataWithFiles)
                return metadata.download_urls(session=self._session), metadata
            except ValueError:
                # No metadata fetcher for this URL; fall back to a single-file download.
                pass

            return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None

        raise Exception(f"No files associated with {source}")

    def _guess_source(self, source: str) -> ModelSource:
        """Turn a source string into a ModelSource object."""
        variants = "|".join(ModelRepoVariant.__members__.values())
        hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
        source_obj: Optional[StringLikeSource] = None
        source_stripped = source.strip('"')

        if Path(source_stripped).exists():  # A local file or directory
            source_obj = LocalModelSource(path=Path(source_stripped))
        elif match := re.match(hf_repoid_re, source):
            source_obj = HFModelSource(
                repo_id=match.group(1),
                variant=ModelRepoVariant(match.group(2)) if match.group(2) else None,  # pass None rather than ''
                subfolder=Path(match.group(3)) if match.group(3) else None,
            )
        elif re.match(r"^https?://[^/]+", source):
            source_obj = URLModelSource(
                url=Url(source),
            )
        else:
            raise ValueError(f"Unsupported model source: '{source}'")
        return source_obj

    # --------------------------------------------------------------------------------------------
    # Internal functions that manage the installer threads
    # --------------------------------------------------------------------------------------------
    def _start_installer_thread(self) -> None:
        self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
        self._install_thread.start()
        self._running = True

    def _install_next_item(self) -> None:
        self._logger.debug(f"Installer thread {threading.get_ident()} starting")
        while True:
            if self._stop_event.is_set():
                break
            self._logger.debug(f"Installer thread {threading.get_ident()} polling")
            try:
                job = self._install_queue.get(timeout=1)
            except Empty:
                continue
            assert job.local_path is not None
            try:
                if job.cancelled:
                    self._signal_job_cancelled(job)

                elif job.errored:
                    self._signal_job_errored(job)

                elif job.waiting or job.downloads_done:
                    self._register_or_install(job)

            except Exception as e:
                # Expected errors include InvalidModelConfigException, DuplicateModelException, OSError, but we must
                # gracefully handle _any_ error here.
                self._set_error(job, e)

            finally:
                # If this was an install of a remote file, clean up the temporary directory.
                if job._install_tmpdir is not None:
                    rmtree(job._install_tmpdir)
                self._install_completed_event.set()
                self._install_queue.task_done()
        self._logger.info(f"Installer thread {threading.get_ident()} exiting")

    def _register_or_install(self, job: ModelInstallJob) -> None:
        # Local jobs will be in the waiting state; remote jobs will be in the downloading state.
        job.total_bytes = self._stat_size(job.local_path)
        job.bytes = job.total_bytes
        self._signal_job_running(job)
        job.config_in.source = str(job.source)
        job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
        # Enter the metadata, if there is any.
        if isinstance(job.source_metadata, (HuggingFaceMetadata)):
            job.config_in.source_api_response = job.source_metadata.api_response

        if job.inplace:
            key = self.register_path(job.local_path, job.config_in)
        else:
            key = self.install_path(job.local_path, job.config_in)
        job.config_out = self.record_store.get_model(key)
        self._signal_job_completed(job)

    def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
        multifile_download_job = install_job._multifile_job
        if multifile_download_job and any(
            x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
        ):
            install_job.set_error(
                InvalidModelConfigException(
                    f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
                )
            )
        else:
            install_job.set_error(excp)
        self._signal_job_errored(install_job)

    # --------------------------------------------------------------------------------------------
    # Internal functions that manage the models directory
    # --------------------------------------------------------------------------------------------
    def _remove_dangling_install_dirs(self) -> None:
        """Remove leftover tmpdirs from aborted installs."""
        path = self._app_config.models_path
        for tmpdir in path.glob(f"{TMPDIR_PREFIX}*"):
            self._logger.info(f"Removing dangling temporary directory {tmpdir}")
            rmtree(tmpdir)

    def _scan_for_missing_models(self) -> list[AnyModelConfig]:
        """Scan the models directory for missing models and return a list of them."""
        missing_models: list[AnyModelConfig] = []
        for model_config in self.record_store.all_models():
            if not (self.app_config.models_path / model_config.path).resolve().exists():
                missing_models.append(model_config)
        return missing_models

    def _register_orphaned_models(self) -> None:
        """Scan the invoke-managed models directory for orphaned models and register them.

        This is typically only used during testing with a new DB or when using the memory DB, because those are the
        only situations in which we may have orphaned models in the models directory.
        """
        installed_model_paths = {
            (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
        }

        # The bool returned by this callback determines if the model is added to the list of models found by the search
        def on_model_found(model_path: Path) -> bool:
            resolved_path = model_path.resolve()
            # Already-registered models should be in the list of found models, but not re-registered.
            if resolved_path in installed_model_paths:
                return True
            # Skip core models entirely - these aren't registered with the model manager.
            for special_directory in [
                self.app_config.models_path / "core",
                self.app_config.convert_cache_dir,
                self.app_config.download_cache_dir,
            ]:
                if resolved_path.is_relative_to(special_directory):
                    return False
            try:
                model_id = self.register_path(model_path)
                self._logger.info(f"Registered {model_path.name} with id {model_id}")
            except DuplicateModelException:
                # In case a duplicate model sneaks by, we will ignore this error - we "found" the model
                pass
            return True

        self._logger.info(f"Scanning {self._app_config.models_path} for orphaned models")
        search = ModelSearch(on_model_found=on_model_found)
        found_models = search.search(self._app_config.models_path)
        self._logger.info(f"{len(found_models)} new models registered")

    def sync_model_path(self, key: str) -> AnyModelConfig:
        """
        Move a model into the location indicated by its basetype, type and name.

        Call this after updating a model's attributes in order to move
        the model's path into the location indicated by its basetype, type and
        name. Applies only to models whose paths are within the root `models_dir`
        directory.

        May raise an UnknownModelException.
        """
        model = self.record_store.get_model(key)
        models_dir = self.app_config.models_path
        old_path = self.app_config.models_path / model.path

        if not old_path.is_relative_to(models_dir):
            # The model is not in the models directory - we don't need to move it.
            return model

        new_path = models_dir / model.base.value / model.type.value / old_path.name

        if old_path == new_path or (new_path.exists() and old_path == new_path.resolve()):
            return model

        self._logger.info(f"Moving {model.name} to {new_path}.")
        new_path = self._move_model(old_path, new_path)
        model.path = new_path.relative_to(models_dir).as_posix()
        self.record_store.update_model(key, ModelRecordChanges(path=model.path))
        return model

    def _copy_model(self, old_path: Path, new_path: Path) -> Path:
        if old_path == new_path:
            return old_path
        new_path.parent.mkdir(parents=True, exist_ok=True)
        if old_path.is_dir():
            copytree(old_path, new_path)
        else:
            copyfile(old_path, new_path)
        return new_path

    def _move_model(self, old_path: Path, new_path: Path) -> Path:
        if old_path == new_path:
            return old_path

        new_path.parent.mkdir(parents=True, exist_ok=True)

        # If the path already exists, then we jigger the name to make it unique.
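        # e.g. "model.safetensors" becomes "model_01.safetensors", then
        # "model_02.safetensors", and so on (illustrative names).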
        counter: int = 1
        while new_path.exists():
            path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
            if not path.exists():
                new_path = path
            counter += 1
        move(old_path, new_path)
        return new_path

    def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
        config = config or ModelRecordChanges()
        hash_algo = self._app_config.hashing_algorithm
        fields = config.model_dump()

        return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)

        # The new model probe API is disabled pending resolution of an issue caused by a change in the ordering of
        # checks. See the commit message for details.
        # try:
        #     return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
        # except InvalidModelConfigException:
        #     return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)  # type: ignore

    def _register(
        self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
    ) -> str:
        config = config or ModelRecordChanges()

        info = info or self._probe(model_path, config)

        model_path = model_path.resolve()

        # Models in the Invoke-managed models dir should use relative paths.
        if model_path.is_relative_to(self.app_config.models_path):
            model_path = model_path.relative_to(self.app_config.models_path)

        info.path = model_path.as_posix()

        if isinstance(info, CheckpointConfigBase):
            # Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
            # invoke-managed legacy config dir, we use a relative path.
            legacy_config_path = self.app_config.legacy_conf_path / info.config_path
            if legacy_config_path.is_relative_to(self.app_config.legacy_conf_path):
                legacy_config_path = legacy_config_path.relative_to(self.app_config.legacy_conf_path)
            info.config_path = legacy_config_path.as_posix()
        self.record_store.add_model(info)
        return info.key

    def _next_id(self) -> int:
        with self._lock:
            id = self._next_job_id
            self._next_job_id += 1
        return id

    def _guess_variant(self) -> Optional[ModelRepoVariant]:
        """Guess the best HuggingFace variant type to download."""
        precision = TorchDevice.choose_torch_dtype()
        return ModelRepoVariant.FP16 if precision == torch.float16 else None

    def _import_local_model(
        self, source: LocalModelSource, config: Optional[ModelRecordChanges] = None
    ) -> ModelInstallJob:
        return ModelInstallJob(
            id=self._next_id(),
            source=source,
            config_in=config or ModelRecordChanges(),
            local_path=Path(source.path),
            inplace=source.inplace or False,
        )

    def _import_from_hf(
        self,
        source: HFModelSource,
        config: Optional[ModelRecordChanges] = None,
    ) -> ModelInstallJob:
        # Add the user's cached access token to HuggingFace requests.
        if source.access_token is None:
            source.access_token = HfFolder.get_token()
        remote_files, metadata = self._remote_files_from_source(source)
        return self._import_remote_model(
            source=source,
            config=config,
            remote_files=remote_files,
            metadata=metadata,
        )

    def _import_from_url(
        self,
        source: URLModelSource,
        config: Optional[ModelRecordChanges] = None,
    ) -> ModelInstallJob:
        remote_files, metadata = self._remote_files_from_source(source)
        return self._import_remote_model(
            source=source,
            config=config,
            metadata=metadata,
            remote_files=remote_files,
        )

    def _import_remote_model(
        self,
        source: HFModelSource | URLModelSource,
        remote_files: List[RemoteModelFile],
        metadata: Optional[AnyModelRepoMetadata],
        config: Optional[ModelRecordChanges],
    ) -> ModelInstallJob:
        if len(remote_files) == 0:
            raise ValueError(f"{source}: No downloadable files found")
        destdir = Path(
            mkdtemp(
                dir=self._app_config.models_path,
                prefix=TMPDIR_PREFIX,
            )
        )
        install_job = ModelInstallJob(
            id=self._next_id(),
            source=source,
            config_in=config or ModelRecordChanges(),
            source_metadata=metadata,
            local_path=destdir,  # local path may change once the download has started due to content-disposition handling
            bytes=0,
            total_bytes=0,
        )
        # Remember the temporary directory for later removal.
        install_job._install_tmpdir = destdir
        install_job.total_bytes = sum((x.size or 0) for x in remote_files)

        multifile_job = self._multifile_download(
            remote_files=remote_files,
            dest=destdir,
            subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
            access_token=source.access_token,
            submit_job=False,  # Important! Don't submit the job until we have set our _download_cache dict
        )
        self._download_cache[multifile_job.id] = install_job
        install_job._multifile_job = multifile_job

        files_string = "file" if len(remote_files) == 1 else "files"
        self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})")
        self._logger.debug(f"remote_files={remote_files}")
        self._download_queue.submit_multifile_download(multifile_job)
        return install_job

    def _stat_size(self, path: Path) -> int:
        size = 0
        if path.is_file():
            size = path.stat().st_size
        elif path.is_dir():
            for root, _, files in os.walk(path):
                size += sum(self._stat_size(Path(root, x)) for x in files)
        return size

    def _multifile_download(
        self,
        remote_files: List[RemoteModelFile],
        dest: Path,
        subfolder: Optional[Path] = None,
        access_token: Optional[str] = None,
        submit_job: bool = True,
    ) -> MultiFileDownloadJob:
        # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
        # we are installing the "vae" subfolder, we do not want to create an additional folder level, such
        # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
        # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
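        # e.g. (illustrative): a remote file at "sdxl-turbo/vae/config.json", installed with
        # subfolder "vae", is rewritten below to "sdxl-turbo_vae/config.json".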
        if subfolder:
            top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
            path_to_remove = top / subfolder  # sdxl-turbo/vae/
            subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
            path_to_add = Path(f"{top}_{subfolder_rename}")
        else:
            path_to_remove = Path(".")
            path_to_add = Path(".")

        parts: List[RemoteModelFile] = []
        for model_file in remote_files:
            assert model_file.size is not None
            parts.append(
                RemoteModelFile(
                    url=model_file.url,  # if a subfolder, then sdxl-turbo_vae/config.json
                    path=path_to_add / model_file.path.relative_to(path_to_remove),
                )
            )

        return self._download_queue.multifile_download(
            parts=parts,
            dest=dest,
            access_token=access_token,
            submit_job=submit_job,
            on_start=self._download_started_callback,
            on_progress=self._download_progress_callback,
            on_complete=self._download_complete_callback,
            on_error=self._download_error_callback,
            on_cancelled=self._download_cancelled_callback,
        )

    # ------------------------------------------------------------------
    # Callbacks are executed by the download queue in a separate thread
    # ------------------------------------------------------------------
    def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
        with self._lock:
            if install_job := self._download_cache.get(download_job.id, None):
                install_job.status = InstallStatus.DOWNLOADING

                if install_job.local_path == install_job._install_tmpdir:  # first time
                    assert download_job.download_path
                    install_job.local_path = download_job.download_path
                install_job.download_parts = download_job.download_parts
                install_job.bytes = sum(x.bytes for x in download_job.download_parts)
                install_job.total_bytes = download_job.total_bytes
                self._signal_job_download_started(install_job)

    def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
        with self._lock:
            if install_job := self._download_cache.get(download_job.id, None):
                if install_job.cancelled:  # This catches the case in which the caller directly calls job.cancel()
                    self._download_queue.cancel_job(download_job)
                else:
                    # Update sizes
                    install_job.bytes = sum(x.bytes for x in download_job.download_parts)
                    install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts)
                    self._signal_job_downloading(install_job)

    def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
        with self._lock:
            if install_job := self._download_cache.pop(download_job.id, None):
                self._signal_job_downloads_done(install_job)
                self._put_in_queue(install_job)  # this starts the installation and registration

            # Let other threads know that the number of downloads has changed
            self._downloads_changed_event.set()

    def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
        with self._lock:
            if install_job := self._download_cache.pop(download_job.id, None):
                assert excp is not None
                self._set_error(install_job, excp)
                self._download_queue.cancel_job(download_job)

            # Let other threads know that the number of downloads has changed
            self._downloads_changed_event.set()

    def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
        with self._lock:
            if install_job := self._download_cache.pop(download_job.id, None):
                self._downloads_changed_event.set()
                # If the install job has already registered an error, do not replace its status with cancelled.
                if not install_job.errored:
                    install_job.cancel()

            # Let other threads know that the number of downloads has changed
            self._downloads_changed_event.set()

    # ------------------------------------------------------------------------------------------------
    # Internal methods that put events on the event bus
    # ------------------------------------------------------------------------------------------------
    def _signal_job_running(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.RUNNING
        self._logger.info(f"Model install started: {job.source}")
        if self._event_bus:
            self._event_bus.emit_model_install_started(job)

    def _signal_job_download_started(self, job: ModelInstallJob) -> None:
        if self._event_bus:
            assert job._multifile_job is not None
            assert job.bytes is not None
            assert job.total_bytes is not None
            self._event_bus.emit_model_install_download_started(job)

    def _signal_job_downloading(self, job: ModelInstallJob) -> None:
        if self._event_bus:
            assert job._multifile_job is not None
            assert job.bytes is not None
            assert job.total_bytes is not None
            self._event_bus.emit_model_install_download_progress(job)

    def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.DOWNLOADS_DONE
        self._logger.info(f"Model download complete: {job.source}")
        if self._event_bus:
            self._event_bus.emit_model_install_downloads_complete(job)

    def _signal_job_completed(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.COMPLETED
        assert job.config_out
        self._logger.info(f"Model install complete: {job.source}")
        self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
        if self._event_bus:
            assert job.local_path is not None
            assert job.config_out is not None
            self._event_bus.emit_model_install_complete(job)

    def _signal_job_errored(self, job: ModelInstallJob) -> None:
        self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
        if self._event_bus:
            assert job.error_type is not None
            assert job.error is not None
            self._event_bus.emit_model_install_error(job)

    def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
        self._logger.info(f"Model install canceled: {job.source}")
        if self._event_bus:
            self._event_bus.emit_model_install_cancelled(job)

    @staticmethod
    def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
        """
        Return a metadata fetcher appropriate for the provided url.

        This used to be more useful, but the number of supported model
        sources has been reduced to HuggingFace alone.
        """
        # The dot in "huggingface.co" is escaped so it matches only the literal hostname.
        if re.match(r"^https?://huggingface\.co/[^/]+/[^/]+$", url.lower()):
            return HuggingFaceMetadataFetch
        raise ValueError(f"Unsupported model source: '{url}'")
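
# A minimal usage sketch (assumes a fully-configured app; the construction
# arguments and source string are illustrative):
#
#   installer = ModelInstallService(app_config, record_store, download_queue)
#   installer.start()
#   job = installer.heuristic_import("stabilityai/sdxl-turbo")
#   installer.wait_for_installs()
#   installer.stop()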