Files
InvokeAI/invokeai/app/services/model_install/model_install_default.py
psychedelicious c5069557f3 fix(mm): fail when model exists at path instead of finding unused new path
When installing a model, the previous, graceful logic would increment a
suffix on the destination path until it found a free path for the model.

But because model file installation and record creation are not in a
transaction, we could end up moving the file successfully but failing to
create the record:
- User attempts to install an already-installed model
- Attempt to move the downloaded model from the download tempdir to the
destination path
- The path already exists
- Add `_1` or similar to the path until we find a path that is free
- Move the model
- Create the model record
- FK constraint violation because we already have a model with that name, but
the model file has already been moved into the invokeai dir.

Closes #8416
2025-08-13 10:40:06 +10:00

948 lines
42 KiB
Python

"""Model installation class."""
import locale
import os
import re
import threading
import time
from pathlib import Path
from queue import Empty, Queue
from shutil import move, rmtree
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
from pydantic_core import Url
from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_install.model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP,
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelSource,
StringLikeSource,
URLModelSource,
)
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
InvalidModelConfigException,
ModelConfigBase,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
HuggingFaceMetadataFetch,
ModelMetadataFetchBase,
ModelMetadataWithFiles,
RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
from invokeai.backend.model_manager.util.lora_metadata_extractor import apply_lora_metadata
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.util import slugify
# The event service type is only needed for annotations, so it is not imported at runtime.
if TYPE_CHECKING:
    from invokeai.app.services.events.events_base import EventServiceBase

# Name prefix of the temporary directories that remote installs create inside the models directory.
# Leftover directories carrying this prefix are deleted when the installer service starts.
TMPDIR_PREFIX: str = "tmpinstall_"
class ModelInstallService(ModelInstallServiceBase):
    """Install, register and remove models for InvokeAI.

    Local installs are queued for installation immediately. Remote (HuggingFace / URL) installs are first
    downloaded through the download queue into a temporary directory inside the models directory, and are
    queued for installation when the download completes. A single background thread drains the install queue.
    """

    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        record_store: ModelRecordServiceBase,
        download_queue: DownloadQueueServiceBase,
        event_bus: Optional["EventServiceBase"] = None,
        session: Optional[Session] = None,
    ):
        """
        Initialize the installer object.

        :param app_config: InvokeAIAppConfig object
        :param record_store: Previously-opened ModelRecordService database
        :param download_queue: Download queue used to fetch remote model files
        :param event_bus: Optional EventService object
        :param session: Optional requests Session, passed to the remote metadata fetchers
        """
        self._app_config = app_config
        self._record_store = record_store
        self._event_bus = event_bus
        self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
        self._install_jobs: List[ModelInstallJob] = []
        self._install_queue: Queue[ModelInstallJob] = Queue()
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self._downloads_changed_event = threading.Event()
        self._install_completed_event = threading.Event()
        self._download_queue = download_queue
        # Maps a multifile download job id to the install job waiting on it.
        self._download_cache: Dict[int, ModelInstallJob] = {}
        self._running = False
        self._session = session
        self._install_thread: Optional[threading.Thread] = None
        self._next_job_id = 0

    @property
    def app_config(self) -> InvokeAIAppConfig:  # noqa D102
        return self._app_config

    @property
    def record_store(self) -> ModelRecordServiceBase:  # noqa D102
        return self._record_store

    @property
    def event_bus(self) -> Optional["EventServiceBase"]:  # noqa D102
        return self._event_bus

    # make the invoker optional here because we don't need it and it
    # makes the installer harder to use outside the web app
    def start(self, invoker: Optional[Invoker] = None) -> None:
        """Start the installer thread, clean up leftovers from previous runs, and migrate legacy config.

        :raises Exception: if the service is already running.
        """
        with self._lock:
            if self._running:
                raise Exception("Attempt to start the installer service twice")
            self._start_installer_thread()
            self._remove_dangling_install_dirs()
            self._migrate_yaml()
            # In normal use, we do not want to scan the models directory - it should never have orphaned models.
            # We should only do the scan when the flag is set (which should only be set when testing).
            if self.app_config.scan_models_on_startup:
                with catch_sigint():
                    self._register_orphaned_models()

            # Check all models' paths and confirm they exist. A model could be missing if it was installed on a volume
            # that isn't currently mounted. In this case, we don't want to delete the model from the database, but we do
            # want to alert the user.
            for model in self._scan_for_missing_models():
                self._logger.warning(f"Missing model file: {model.name} at {model.path}")

    def stop(self, invoker: Optional[Invoker] = None) -> None:
        """Stop the installer thread; after this the object can be deleted and garbage collected.

        :raises Exception: if the service was never started.
        """
        if not self._running:
            raise Exception("Attempt to stop the install service before it was started")
        self._logger.debug("calling stop_event.set()")
        self._stop_event.set()
        self._clear_pending_jobs()
        self._download_cache.clear()
        assert self._install_thread is not None
        self._install_thread.join()
        self._running = False

    def _clear_pending_jobs(self) -> None:
        """Cancel every non-terminal job, then drain the install queue."""
        for job in self.list_jobs():
            if not job.in_terminal_state:
                self._logger.warning(f"Cancelling job {job.id}")
                self.cancel_job(job)
        # Drain anything still queued, marking each item done so that Queue.join() callers are released.
        while True:
            try:
                self._install_queue.get(block=False)
                self._install_queue.task_done()
            except Empty:
                break

    def _put_in_queue(self, job: ModelInstallJob) -> None:
        """Queue a job for installation, or cancel it if the service is stopping."""
        if self._stop_event.is_set():
            self.cancel_job(job)
        else:
            self._install_queue.put(job)

    def register_path(
        self,
        model_path: Union[Path, str],
        config: Optional[ModelRecordChanges] = None,
    ) -> str:  # noqa D102
        model_path = Path(model_path)
        config = config or ModelRecordChanges()
        if not config.source:
            config.source = model_path.resolve().as_posix()
            config.source_type = ModelSourceType.Path
        return self._register(model_path, config)

    def install_path(
        self,
        model_path: Union[Path, str],
        config: Optional[ModelRecordChanges] = None,
    ) -> str:  # noqa D102
        model_path = Path(model_path)
        config = config or ModelRecordChanges()
        info: AnyModelConfig = self._probe(model_path, config)  # type: ignore

        if preferred_name := config.name:
            if model_path.is_file():
                # Careful! Don't use pathlib.Path(...).with_suffix - it would replace the last dotted part of the
                # preferred name (e.g. "model-v1.5" would become "model-v1.safetensors").
                preferred_name = f"{preferred_name}{model_path.suffix}"

        dest_path = (
            self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
        )
        try:
            new_path = self._move_model(model_path, dest_path)
        except FileExistsError as excp:
            # Report the destination name: it may come from config.name rather than the source file name.
            raise DuplicateModelException(
                f"A model named {dest_path.name} is already installed at {dest_path.as_posix()}"
            ) from excp

        try:
            return self._register(new_path, config, info)
        except Exception:
            # Moving the files and creating the record are not transactional. If the record could not be created,
            # move the files back so that no unregistered model is left behind in the models directory.
            if new_path != model_path:
                try:
                    move(new_path, model_path)
                except OSError as rollback_excp:
                    self._logger.error(f"Failed to move {new_path} back to {model_path}: {rollback_excp}")
            raise

    def heuristic_import(
        self,
        source: str,
        config: Optional[ModelRecordChanges] = None,
        access_token: Optional[str] = None,
        inplace: Optional[bool] = False,
    ) -> ModelInstallJob:
        """Install a model using pattern matching to infer the type of source.

        :param source: A local path, a HuggingFace repo id (``repo_id[:variant[:subfolder]]``) or an http(s) URL.
        :param access_token: Used for HuggingFace and URL sources only.
        :param inplace: Used for local sources only; register the model where it is instead of moving it.
        """
        source_obj = self._guess_source(source)
        if isinstance(source_obj, LocalModelSource):
            source_obj.inplace = inplace
        elif isinstance(source_obj, (HFModelSource, URLModelSource)):
            source_obj.access_token = access_token
        return self.import_model(source_obj, config)

    def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob:  # noqa D102
        similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
        if similar_jobs:
            self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
            return similar_jobs[0]

        if isinstance(source, LocalModelSource):
            install_job = self._import_local_model(source, config)
            self._put_in_queue(install_job)  # synchronously install
        elif isinstance(source, HFModelSource):
            install_job = self._import_from_hf(source, config)
        elif isinstance(source, URLModelSource):
            install_job = self._import_from_url(source, config)
        else:
            raise ValueError(f"Unsupported model source: '{type(source)}'")

        self._install_jobs.append(install_job)
        return install_job

    def list_jobs(self) -> List[ModelInstallJob]:  # noqa D102
        return self._install_jobs

    def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:  # noqa D102
        return [x for x in self._install_jobs if x.source == source]

    def get_job_by_id(self, id: int) -> ModelInstallJob:  # noqa D102
        jobs = [x for x in self._install_jobs if x.id == id]
        if not jobs:
            raise ValueError(f"No job with id {id} known")
        assert len(jobs) == 1
        assert isinstance(jobs[0], ModelInstallJob)
        return jobs[0]

    def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
        """Block until the indicated job has reached terminal state, or when timeout limit reached.

        :param timeout: Seconds to wait; 0 means wait indefinitely.
        :raises TimeoutError: if the timeout is exceeded.
        """
        start = time.time()
        while not job.in_terminal_state:
            if self._install_completed_event.wait(timeout=5):  # in case we miss an event
                self._install_completed_event.clear()
            if timeout > 0 and time.time() - start > timeout:
                raise TimeoutError("Timeout exceeded")
        return job

    def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:  # noqa D102
        """Block until all installation jobs are done.

        The timeout (seconds, 0 = no limit) applies only to the download phase; the final Queue.join() waits
        without a limit.
        """
        start = time.time()
        while len(self._download_cache) > 0:
            if self._downloads_changed_event.wait(timeout=0.25):  # in case we miss an event
                self._downloads_changed_event.clear()
            if timeout > 0 and time.time() - start > timeout:
                raise TimeoutError("Timeout exceeded")
        self._install_queue.join()
        return self._install_jobs

    def cancel_job(self, job: ModelInstallJob) -> None:
        """Cancel the indicated job, including its multifile download if it has one."""
        job.cancel()
        self._logger.warning(f"Cancelling {job.source}")
        if dj := job._multifile_job:
            self._download_queue.cancel_job(dj)

    def prune_jobs(self) -> None:
        """Prune all completed and errored jobs."""
        unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
        self._install_jobs = unfinished_jobs

    def _migrate_yaml(self) -> None:
        """One-time migration of the models listed in a legacy v3 ``models.yaml`` into the record store.

        Models are migrated only when the record store is empty. In either case an existing legacy file is
        renamed to ``models.yaml.bak`` so it is not processed again, and the config's legacy path is cleared.

        :raises ValueError: if the legacy file's ``__metadata__`` version is not ``3.0.0``.
        """
        db_models = self.record_store.all_models()

        legacy_models_yaml_path = (
            self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
        )

        # The old path may be relative to the root path
        if not legacy_models_yaml_path.exists():
            legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)

        if legacy_models_yaml_path.exists():
            with open(legacy_models_yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
                legacy_models_yaml = yaml.safe_load(file)

            yaml_metadata = legacy_models_yaml.pop("__metadata__")
            yaml_version = yaml_metadata.get("version")

            if yaml_version != "3.0.0":
                raise ValueError(
                    f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
                )

            if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0:
                # Only announce the migration when it will actually run.
                self._logger.info(
                    f"Starting one-time migration of {len(legacy_models_yaml.items())} models from {str(legacy_models_yaml_path)}. This may take a few minutes."
                )
                for model_key, stanza in legacy_models_yaml.items():
                    # Legacy keys have the form "<base>/<type>/<name>"; only the name is kept.
                    _, _, model_name = str(model_key).split("/")
                    model_path = Path(stanza["path"])
                    if not model_path.is_absolute():
                        model_path = self._app_config.models_path / model_path
                    model_path = model_path.resolve()

                    config = ModelRecordChanges(
                        name=model_name,
                        description=stanza.get("description"),
                    )
                    legacy_config_path = stanza.get("config")
                    if legacy_config_path:
                        # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
                        legacy_config_path = self._app_config.root_path / legacy_config_path
                        if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
                            legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
                        config.config_path = str(legacy_config_path)
                    try:
                        key = self.register_path(model_path=model_path, config=config)
                        self._logger.info(f"Migrated {model_name} with id {key}")
                    except Exception as e:
                        # Best effort: one unmigratable model must not block the rest.
                        self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
            else:
                self._logger.info(
                    f"Skipping migration of {str(legacy_models_yaml_path)}: the model database is not empty or the file lists no models."
                )

            # Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
            legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))

        # Unset the path - we are done with it either way
        self._app_config.legacy_models_yaml_path = None

    def unregister(self, key: str) -> None:  # noqa D102
        self.record_store.del_model(key)

    def delete(self, key: str) -> None:  # noqa D102
        """Unregister the model. Delete its files only if they are within our models directory."""
        model = self.record_store.get_model(key)
        model_path = self.app_config.models_path / model.path

        if model_path.is_relative_to(self.app_config.models_path):
            # If the models is in the Invoke-managed models dir, we delete it
            self.unconditionally_delete(key)
        else:
            # Else we only unregister it, leaving the file in place
            self.unregister(key)

    def unconditionally_delete(self, key: str) -> None:  # noqa D102
        model = self.record_store.get_model(key)
        model_path = self.app_config.models_path / model.path
        if model_path.is_file() or model_path.is_symlink():
            model_path.unlink()
        elif model_path.is_dir():
            rmtree(model_path)
        self.unregister(key)

    @classmethod
    def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
        """Return the per-source directory inside the download cache (the source string is slugified)."""
        escaped_source = slugify(str(source))
        return app_config.download_cache_path / escaped_source

    def download_and_cache_model(
        self,
        source: str | AnyHttpUrl,
    ) -> Path:
        """Download the model file located at source to the models cache and return its Path.

        If the cache directory for this source already has content, its first entry is returned without
        downloading. Otherwise this blocks until the download finishes.

        :raises Exception: carrying the download job's error if the download did not complete.
        """
        model_path = self._download_cache_path(str(source), self._app_config)

        # We expect the cache directory to contain one and only one downloaded file or directory.
        # We don't know the file's name in advance, as it is set by the download
        # content-disposition header.
        if model_path.exists():
            contents: List[Path] = list(model_path.iterdir())
            if len(contents) > 0:
                return contents[0]

        model_path.mkdir(parents=True, exist_ok=True)
        model_source = self._guess_source(str(source))
        remote_files, _ = self._remote_files_from_source(model_source)
        job = self._multifile_download(
            dest=model_path,
            remote_files=remote_files,
            subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
        )
        files_string = "file" if len(remote_files) == 1 else "files"
        self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
        self._download_queue.wait_for_job(job)
        if job.complete:
            assert job.download_path is not None
            return job.download_path
        else:
            raise Exception(job.error)

    def _remote_files_from_source(
        self, source: ModelSource
    ) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
        """Return the files to download for a remote source, plus its repo metadata when available.

        For a URL with no supported metadata fetcher, the URL itself is returned as a single file with no metadata.

        :raises Exception: if the source is neither a HuggingFace nor a URL source.
        """
        metadata = None
        if isinstance(source, HFModelSource):
            metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
            assert isinstance(metadata, ModelMetadataWithFiles)
            return (
                metadata.download_urls(
                    variant=source.variant or self._guess_variant(),
                    subfolder=source.subfolder,
                    session=self._session,
                ),
                metadata,
            )

        if isinstance(source, URLModelSource):
            try:
                fetcher = self.get_fetcher_from_url(str(source.url))
                kwargs: dict[str, Any] = {"session": self._session}
                metadata = fetcher(**kwargs).from_url(source.url)
                assert isinstance(metadata, ModelMetadataWithFiles)
                return metadata.download_urls(session=self._session), metadata
            except ValueError:
                # No metadata fetcher for this URL - fall through to a plain single-file download.
                pass

            return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None

        raise Exception(f"No files associated with {source}")

    def _guess_source(self, source: str) -> ModelSource:
        """Turn a source string into a ModelSource object.

        Surrounding double quotes are ignored. The string is interpreted, in order, as an existing local path,
        a HuggingFace ``repo_id[:variant[:subfolder]]``, or an http(s) URL.

        :raises ValueError: if the string matches none of these forms.
        """
        variants = "|".join(ModelRepoVariant.__members__.values())
        hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
        source_obj: Optional[StringLikeSource] = None
        # Use the quote-stripped string for every check, so a quoted repo id or URL is recognized just like a
        # quoted path.
        source_stripped = source.strip('"')

        if Path(source_stripped).exists():  # A local file or directory
            source_obj = LocalModelSource(path=Path(source_stripped))
        elif match := re.match(hf_repoid_re, source_stripped):
            source_obj = HFModelSource(
                repo_id=match.group(1),
                variant=ModelRepoVariant(match.group(2)) if match.group(2) else None,  # pass None rather than ''
                subfolder=Path(match.group(3)) if match.group(3) else None,
            )
        elif re.match(r"^https?://[^/]+", source_stripped):
            source_obj = URLModelSource(
                url=Url(source_stripped),
            )
        else:
            raise ValueError(f"Unsupported model source: '{source}'")
        return source_obj

    # --------------------------------------------------------------------------------------------
    # Internal functions that manage the installer threads
    # --------------------------------------------------------------------------------------------
    def _start_installer_thread(self) -> None:
        self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
        self._install_thread.start()
        self._running = True

    def _install_next_item(self) -> None:
        """Installer thread main loop: process queued install jobs until the stop event is set."""
        self._logger.debug(f"Installer thread {threading.get_ident()} starting")
        while True:
            if self._stop_event.is_set():
                break
            self._logger.debug(f"Installer thread {threading.get_ident()} polling")
            try:
                # Poll with a timeout so the stop event is noticed even when the queue is idle.
                job = self._install_queue.get(timeout=1)
            except Empty:
                continue

            try:
                assert job.local_path is not None
                if job.cancelled:
                    self._signal_job_cancelled(job)
                elif job.errored:
                    self._signal_job_errored(job)
                elif job.waiting or job.downloads_done:
                    self._register_or_install(job)
            except Exception as e:
                # Expected errors include InvalidModelConfigException, DuplicateModelException, OSError, but we must
                # gracefully handle _any_ error here.
                self._set_error(job, e)
            finally:
                # if this is an install of a remote file, then clean up the temporary directory.
                # A cleanup failure must not kill this thread or skip task_done(); otherwise
                # wait_for_installs() would block forever in Queue.join().
                if job._install_tmpdir is not None:
                    try:
                        rmtree(job._install_tmpdir)
                    except OSError as cleanup_excp:
                        self._logger.warning(
                            f"Unable to remove temporary directory {job._install_tmpdir}: {cleanup_excp}"
                        )
                self._install_completed_event.set()
                self._install_queue.task_done()
        self._logger.info(f"Installer thread {threading.get_ident()} exiting")

    def _register_or_install(self, job: ModelInstallJob) -> None:
        """Register (inplace) or install (move into the models dir) the job's local files, then mark it complete."""
        # local jobs will be in waiting state, remote jobs will be downloading state
        job.total_bytes = self._stat_size(job.local_path)
        job.bytes = job.total_bytes
        self._signal_job_running(job)
        job.config_in.source = str(job.source)
        job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
        # enter the metadata, if there is any
        if isinstance(job.source_metadata, (HuggingFaceMetadata)):
            job.config_in.source_api_response = job.source_metadata.api_response

        if job.inplace:
            key = self.register_path(job.local_path, job.config_in)
        else:
            key = self.install_path(job.local_path, job.config_in)
        job.config_out = self.record_store.get_model(key)
        self._signal_job_completed(job)

    def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
        """Record an error on the job and emit the error event.

        If any downloaded part was served as text/html, the error is replaced by a clearer message, because an
        HTML page in place of a model usually means an access token was required.
        """
        multifile_download_job = install_job._multifile_job
        if multifile_download_job and any(
            x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
        ):
            install_job.set_error(
                InvalidModelConfigException(
                    f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
                )
            )
        else:
            install_job.set_error(excp)
        self._signal_job_errored(install_job)

    # --------------------------------------------------------------------------------------------
    # Internal functions that manage the models directory
    # --------------------------------------------------------------------------------------------
    def _remove_dangling_install_dirs(self) -> None:
        """Remove leftover tmpdirs from aborted installs."""
        path = self._app_config.models_path
        for tmpdir in path.glob(f"{TMPDIR_PREFIX}*"):
            self._logger.info(f"Removing dangling temporary directory {tmpdir}")
            rmtree(tmpdir)

    def _scan_for_missing_models(self) -> list[AnyModelConfig]:
        """Return the registered models whose files do not exist on disk."""
        missing_models: list[AnyModelConfig] = []
        for model_config in self.record_store.all_models():
            if not (self.app_config.models_path / model_config.path).resolve().exists():
                missing_models.append(model_config)
        return missing_models

    def _register_orphaned_models(self) -> None:
        """Scan the invoke-managed models directory for orphaned models and registers them.

        This is typically only used during testing with a new DB or when using the memory DB, because those are the
        only situations in which we may have orphaned models in the models directory.
        """
        installed_model_paths = {
            (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
        }

        # Directories whose contents are never registered. They are resolved once, because the candidate paths
        # below are resolved and the config's *_dir values may be relative to the root directory.
        special_directories = [
            (self.app_config.models_path / "core").resolve(),
            (self.app_config.root_path / self.app_config.convert_cache_dir).resolve(),
            self.app_config.download_cache_path.resolve(),
        ]
        newly_registered = 0

        # The bool returned by this callback determines if the model is added to the list of models found by the search
        def on_model_found(model_path: Path) -> bool:
            nonlocal newly_registered
            resolved_path = model_path.resolve()

            # Already registered models should be in the list of found models, but not re-registered.
            if resolved_path in installed_model_paths:
                return True

            # Skip core models entirely - these aren't registered with the model manager.
            if any(resolved_path.is_relative_to(special_directory) for special_directory in special_directories):
                return False

            try:
                model_id = self.register_path(model_path)
                newly_registered += 1
                self._logger.info(f"Registered {model_path.name} with id {model_id}")
            except DuplicateModelException:
                # In case a duplicate models sneaks by, we will ignore this error - we "found" the model
                pass
            return True

        self._logger.info(f"Scanning {self._app_config.models_path} for orphaned models")
        search = ModelSearch(on_model_found=on_model_found)
        found_models = search.search(self._app_config.models_path)
        # found_models also contains already-registered models, so report the two counts separately.
        self._logger.info(f"{len(found_models)} models found, {newly_registered} new models registered")

    def sync_model_path(self, key: str) -> AnyModelConfig:
        """
        Move model into the location indicated by its basetype, type and name.

        Call this after updating a model's attributes in order to move
        the model's path into the location indicated by its basetype, type and
        name. Applies only to models whose paths are within the root `models_dir`
        directory.

        May raise an UnknownModelException.
        Raises FileExistsError if something else already occupies the target location.
        """
        model = self.record_store.get_model(key)
        models_dir = self.app_config.models_path
        old_path = self.app_config.models_path / model.path

        if not old_path.is_relative_to(models_dir):
            # The model is not in the models directory - we don't need to move it.
            return model

        new_path = models_dir / model.base.value / model.type.value / old_path.name

        if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
            return model

        self._logger.info(f"Moving {model.name} to {new_path}.")
        new_path = self._move_model(old_path, new_path)
        model.path = new_path.relative_to(models_dir).as_posix()
        self.record_store.update_model(key, ModelRecordChanges(path=model.path))
        return model

    def _move_model(self, old_path: Path, new_path: Path) -> Path:
        """Move a model file or directory to new_path, creating parent directories as needed.

        :returns: The path the model now lives at (old_path unchanged if the two are equal).
        :raises FileExistsError: if new_path already exists; we never pick an alternative name.
        """
        if old_path == new_path:
            return old_path

        if new_path.exists():
            raise FileExistsError(f"Cannot move {old_path} to {new_path}: destination already exists")

        new_path.parent.mkdir(parents=True, exist_ok=True)
        move(old_path, new_path)
        return new_path

    def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
        """Identify the model at model_path and return its config, trying the legacy probe first."""
        config = config or ModelRecordChanges()
        hash_algo = self._app_config.hashing_algorithm
        fields = config.model_dump()

        # WARNING!
        # The legacy probe relies on the implicit order of tests to determine model classification.
        # This can lead to regressions between the legacy and new probes.
        # Do NOT change the order of `probe` and `classify` without implementing one of the following fixes:
        # Short-term fix: `classify` tests `matches` in the same order as the legacy probe.
        # Long-term fix: Improve `matches` to be more specific so that only one config matches
        #   any given model - eliminating ambiguity and removing reliance on order.
        # After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
        try:
            return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)  # type: ignore
        except InvalidModelConfigException:
            return ModelConfigBase.classify(model_path, hash_algo, **fields)

    def _register(
        self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
    ) -> str:
        """Add a record for the model at model_path and return its key.

        Probes the model unless `info` is supplied. Paths inside the models directory (and checkpoint config
        paths inside the legacy conf directory) are stored relative to those directories.
        """
        config = config or ModelRecordChanges()
        info = info or self._probe(model_path, config)

        model_path = model_path.resolve()

        # Apply LoRA metadata if applicable
        model_images_path = self.app_config.models_path / "model_images"
        apply_lora_metadata(info, model_path, model_images_path)

        # Models in the Invoke-managed models dir should use relative paths.
        if model_path.is_relative_to(self.app_config.models_path):
            model_path = model_path.relative_to(self.app_config.models_path)

        info.path = model_path.as_posix()

        if isinstance(info, CheckpointConfigBase):
            # Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
            # invoke-managed legacy config dir, we use a relative path.
            legacy_config_path = self.app_config.legacy_conf_path / info.config_path
            if legacy_config_path.is_relative_to(self.app_config.legacy_conf_path):
                legacy_config_path = legacy_config_path.relative_to(self.app_config.legacy_conf_path)
            info.config_path = legacy_config_path.as_posix()
        self.record_store.add_model(info)
        return info.key

    def _next_id(self) -> int:
        """Return a new, unique install job id."""
        with self._lock:
            id = self._next_job_id
            self._next_job_id += 1
        return id

    def _guess_variant(self) -> Optional[ModelRepoVariant]:
        """Guess the best HuggingFace variant type to download."""
        precision = TorchDevice.choose_torch_dtype()
        return ModelRepoVariant.FP16 if precision == torch.float16 else None

    def _import_local_model(
        self, source: LocalModelSource, config: Optional[ModelRecordChanges] = None
    ) -> ModelInstallJob:
        """Create (but do not queue) an install job for a local file or directory."""
        return ModelInstallJob(
            id=self._next_id(),
            source=source,
            config_in=config or ModelRecordChanges(),
            local_path=Path(source.path),
            inplace=source.inplace or False,
        )

    def _import_from_hf(
        self,
        source: HFModelSource,
        config: Optional[ModelRecordChanges] = None,
    ) -> ModelInstallJob:
        """Start downloading a HuggingFace model and return its install job."""
        # Add user's cached access token to HuggingFace requests
        if source.access_token is None:
            source.access_token = HfFolder.get_token()
        remote_files, metadata = self._remote_files_from_source(source)
        return self._import_remote_model(
            source=source,
            config=config,
            remote_files=remote_files,
            metadata=metadata,
        )

    def _import_from_url(
        self,
        source: URLModelSource,
        config: Optional[ModelRecordChanges] = None,
    ) -> ModelInstallJob:
        """Start downloading a model from a URL and return its install job."""
        remote_files, metadata = self._remote_files_from_source(source)
        return self._import_remote_model(
            source=source,
            config=config,
            metadata=metadata,
            remote_files=remote_files,
        )

    def _import_remote_model(
        self,
        source: HFModelSource | URLModelSource,
        remote_files: List[RemoteModelFile],
        metadata: Optional[AnyModelRepoMetadata],
        config: Optional[ModelRecordChanges],
    ) -> ModelInstallJob:
        """Download remote_files into a fresh temporary dir under the models dir and return the install job.

        The job is queued for installation by the download-complete callback.

        :raises ValueError: if there are no files to download.
        """
        if len(remote_files) == 0:
            raise ValueError(f"{source}: No downloadable files found")
        destdir = Path(
            mkdtemp(
                dir=self._app_config.models_path,
                prefix=TMPDIR_PREFIX,
            )
        )
        install_job = ModelInstallJob(
            id=self._next_id(),
            source=source,
            config_in=config or ModelRecordChanges(),
            source_metadata=metadata,
            local_path=destdir,  # local path may change once the download has started due to content-disposition handling
            bytes=0,
            total_bytes=0,
        )
        # remember the temporary directory for later removal
        install_job._install_tmpdir = destdir
        install_job.total_bytes = sum((x.size or 0) for x in remote_files)

        multifile_job = self._multifile_download(
            remote_files=remote_files,
            dest=destdir,
            subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
            access_token=source.access_token,
            submit_job=False,  # Important! Don't submit the job until we have set our _download_cache dict
        )
        self._download_cache[multifile_job.id] = install_job
        install_job._multifile_job = multifile_job

        files_string = "file" if len(remote_files) == 1 else "files"
        self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})")
        self._logger.debug(f"remote_files={remote_files}")
        self._download_queue.submit_multifile_download(multifile_job)
        return install_job

    def _stat_size(self, path: Path) -> int:
        """Return the size in bytes of a file, or the total size of all files under a directory (0 otherwise)."""
        size = 0
        if path.is_file():
            size = path.stat().st_size
        elif path.is_dir():
            for root, _, files in os.walk(path):
                size += sum(self._stat_size(Path(root, x)) for x in files)
        return size

    def _multifile_download(
        self,
        remote_files: List[RemoteModelFile],
        dest: Path,
        subfolder: Optional[Path] = None,
        access_token: Optional[str] = None,
        submit_job: bool = True,
    ) -> MultiFileDownloadJob:
        """Create a multifile download of remote_files into dest, wired to this service's callbacks."""
        # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
        # we are installing the "vae" subfolder, we do not want to create an additional folder level, such
        # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
        # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
        if subfolder:
            top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
            path_to_remove = top / subfolder  # sdxl-turbo/vae/
            # NOTE(review): `subfolder.name` never contains a path separator, so these replaces are no-ops and a
            # nested subfolder "a/b" yields "<top>_b". Confirm whether the full subfolder path was intended.
            subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
            path_to_add = Path(f"{top}_{subfolder_rename}")
        else:
            path_to_remove = Path(".")
            path_to_add = Path(".")

        parts: List[RemoteModelFile] = []
        for model_file in remote_files:
            assert model_file.size is not None
            parts.append(
                RemoteModelFile(
                    url=model_file.url,  # if a subfolder, then sdxl-turbo_vae/config.json
                    path=path_to_add / model_file.path.relative_to(path_to_remove),
                )
            )

        return self._download_queue.multifile_download(
            parts=parts,
            dest=dest,
            access_token=access_token,
            submit_job=submit_job,
            on_start=self._download_started_callback,
            on_progress=self._download_progress_callback,
            on_complete=self._download_complete_callback,
            on_error=self._download_error_callback,
            on_cancelled=self._download_cancelled_callback,
        )

    # ------------------------------------------------------------------
    # Callbacks are executed by the download queue in a separate thread
    # ------------------------------------------------------------------
    def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
        """Mark the install job as downloading and adopt the download's real local path."""
        with self._lock:
            if install_job := self._download_cache.get(download_job.id, None):
                install_job.status = InstallStatus.DOWNLOADING

                if install_job.local_path == install_job._install_tmpdir:  # first time
                    assert download_job.download_path
                    install_job.local_path = download_job.download_path
                install_job.download_parts = download_job.download_parts
                install_job.bytes = sum(x.bytes for x in download_job.download_parts)
                install_job.total_bytes = download_job.total_bytes
                self._signal_job_download_started(install_job)

    def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
        """Update byte counts on the install job, or cancel the download if the install job was cancelled."""
        with self._lock:
            if install_job := self._download_cache.get(download_job.id, None):
                if install_job.cancelled:  # This catches the case in which the caller directly calls job.cancel()
                    self._download_queue.cancel_job(download_job)
                else:
                    # update sizes
                    install_job.bytes = sum(x.bytes for x in download_job.download_parts)
                    install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts)
                    self._signal_job_downloading(install_job)

    def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
        """Hand the finished download over to the installer thread."""
        with self._lock:
            if install_job := self._download_cache.pop(download_job.id, None):
                self._signal_job_downloads_done(install_job)
                self._put_in_queue(install_job)  # this starts the installation and registration

            # Let other threads know that the number of downloads has changed
            self._downloads_changed_event.set()

    def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
        """Record the download error on the install job and cancel the rest of the download."""
        with self._lock:
            if install_job := self._download_cache.pop(download_job.id, None):
                assert excp is not None
                self._set_error(install_job, excp)
                self._download_queue.cancel_job(download_job)

            # Let other threads know that the number of downloads has changed
            self._downloads_changed_event.set()

    def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
        """Mark the install job cancelled, unless it already recorded an error."""
        with self._lock:
            if install_job := self._download_cache.pop(download_job.id, None):
                # if install job has already registered an error, then do not replace its status with cancelled
                if not install_job.errored:
                    install_job.cancel()

            # Let other threads know that the number of downloads has changed
            self._downloads_changed_event.set()

    # ------------------------------------------------------------------------------------------------
    # Internal methods that put events on the event bus
    # ------------------------------------------------------------------------------------------------
    def _signal_job_running(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.RUNNING
        self._logger.info(f"Model install started: {job.source}")
        if self._event_bus:
            self._event_bus.emit_model_install_started(job)

    def _signal_job_download_started(self, job: ModelInstallJob) -> None:
        if self._event_bus:
            assert job._multifile_job is not None
            assert job.bytes is not None
            assert job.total_bytes is not None
            self._event_bus.emit_model_install_download_started(job)

    def _signal_job_downloading(self, job: ModelInstallJob) -> None:
        if self._event_bus:
            assert job._multifile_job is not None
            assert job.bytes is not None
            assert job.total_bytes is not None
            self._event_bus.emit_model_install_download_progress(job)

    def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.DOWNLOADS_DONE
        self._logger.info(f"Model download complete: {job.source}")
        if self._event_bus:
            self._event_bus.emit_model_install_downloads_complete(job)

    def _signal_job_completed(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.COMPLETED
        assert job.config_out
        self._logger.info(f"Model install complete: {job.source}")
        self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
        if self._event_bus:
            assert job.local_path is not None
            assert job.config_out is not None
            self._event_bus.emit_model_install_complete(job)

    def _signal_job_errored(self, job: ModelInstallJob) -> None:
        self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
        if self._event_bus:
            assert job.error_type is not None
            assert job.error is not None
            self._event_bus.emit_model_install_error(job)

    def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
        self._logger.info(f"Model install canceled: {job.source}")
        if self._event_bus:
            self._event_bus.emit_model_install_cancelled(job)

    @staticmethod
    def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
        """
        Return a metadata fetcher appropriate for provided url.

        This used to be more useful, but the number of supported model
        sources has been reduced to HuggingFace alone.

        :raises ValueError: if the URL is not a ``https?://huggingface.co/<owner>/<repo>`` URL.
        """
        # The dot is escaped so that only the real huggingface.co host matches.
        if re.match(r"^https?://huggingface\.co/[^/]+/[^/]+$", url.lower()):
            return HuggingFaceMetadataFetch
        raise ValueError(f"Unsupported model source: '{url}'")