InvokeAI/invokeai/app/services/model_install/model_install_default.py
psychedelicious 8c15d14099 fix: use locale encoding
We have had a few bugs with v4 related to file encodings, especially on Windows.

Windows uses its own character encodings instead of `utf-8`, often `cp1252`. Some characters cannot be decoded using `utf-8`, causing `UnicodeDecodeError`.

There are a couple of places where this can cause problems:
- In the installer bootstrap, we install or upgrade `pip` via `subprocess` and decode its output.

  The input to this includes the user's home dir. In #6105, the user had one of the problematic characters in their username. `subprocess` attempts and fails to decode the username, which crashes the installer.

  To fix this, we need to use `locale.getpreferredencoding()` when executing the command (see the sketch after this list).
- Similarly, in the model install service and config class, we attempt to load a yaml config file. If a problematic character is in the path to the file (which often includes the user's home dir), we can get the same error.

  One example is #6129, in which the `models.yaml` migration fails.

  To fix this, we need to open the file with `locale.getpreferredencoding()` as the encoding, as sketched below.
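
Both fixes boil down to asking Python for the locale's preferred encoding instead of assuming `utf-8`. A minimal sketch of the two patterns (the command and the file path here are illustrative, not the installer's exact code):

```python
import locale
import subprocess
import sys

import yaml

encoding = locale.getpreferredencoding()  # e.g. "cp1252" on many Windows systems

# Pattern 1: decode subprocess output with the locale's encoding, so output
# that contains the user's home dir cannot raise UnicodeDecodeError.
result = subprocess.run([sys.executable, "-m", "pip", "--version"], capture_output=True)
print(result.stdout.decode(encoding, errors="replace"))

# Pattern 2: open text files (such as a models.yaml) with the same encoding.
with open("models.yaml", "rt", encoding=encoding) as file:
    data = yaml.safe_load(file)
```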

"""Model installation class."""

import locale
import os
import re
import signal
import threading
import time
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union

import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
from requests import Session

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    CheckpointConfigBase,
    InvalidModelConfigException,
    ModelRepoVariant,
    ModelSourceType,
)
from invokeai.backend.model_manager.metadata import (
    AnyModelRepoMetadata,
    HuggingFaceMetadataFetch,
    ModelMetadataFetchBase,
    ModelMetadataWithFiles,
    RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device

from .model_install_base import (
    MODEL_SOURCE_TO_TYPE_MAP,
    HFModelSource,
    InstallStatus,
    LocalModelSource,
    ModelInstallJob,
    ModelInstallServiceBase,
    ModelSource,
    StringLikeSource,
    URLModelSource,
)

TMPDIR_PREFIX = "tmpinstall_"


class ModelInstallService(ModelInstallServiceBase):
    """Class for InvokeAI model installation."""

    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        record_store: ModelRecordServiceBase,
        download_queue: DownloadQueueServiceBase,
        event_bus: Optional[EventServiceBase] = None,
        session: Optional[Session] = None,
    ):
        """
        Initialize the installer object.

        :param app_config: InvokeAIAppConfig object
        :param record_store: Previously-opened ModelRecordService database
        :param event_bus: Optional EventService object
        """
        self._app_config = app_config
        self._record_store = record_store
        self._event_bus = event_bus
        self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
        self._install_jobs: List[ModelInstallJob] = []
        self._install_queue: Queue[ModelInstallJob] = Queue()
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self._downloads_changed_event = threading.Event()
        self._install_completed_event = threading.Event()
        self._download_queue = download_queue
        self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
        self._running = False
        self._session = session
        self._install_thread: Optional[threading.Thread] = None
        self._next_job_id = 0

    @property
    def app_config(self) -> InvokeAIAppConfig:  # noqa D102
        return self._app_config

    @property
    def record_store(self) -> ModelRecordServiceBase:  # noqa D102
        return self._record_store

    @property
    def event_bus(self) -> Optional[EventServiceBase]:  # noqa D102
        return self._event_bus

    # make the invoker optional here because we don't need it and it
    # makes the installer harder to use outside the web app
    def start(self, invoker: Optional[Invoker] = None) -> None:
        """Start the installer thread."""

        # Yes, this is weird. When the installer thread is running, the
        # thread masks the ^C signal. When we receive a
        # SIGINT, we stop the thread, reset SIGINT, and send a new
        # SIGINT to the parent process.
        def sigint_handler(signum, frame):
            self.stop()
            signal.signal(signal.SIGINT, signal.SIG_DFL)
            signal.raise_signal(signal.SIGINT)

        signal.signal(signal.SIGINT, sigint_handler)

        with self._lock:
            if self._running:
                raise Exception("Attempt to start the installer service twice")
            self._start_installer_thread()
            self._remove_dangling_install_dirs()
            self._migrate_yaml()
            # In normal use, we do not want to scan the models directory - it should never have orphaned models.
            # We should only do the scan when the flag is set (which should only be set when testing).
            if self.app_config.scan_models_on_startup:
                self._register_orphaned_models()

            # Check all models' paths and confirm they exist. A model could be missing if it was installed on a volume
            # that isn't currently mounted. In this case, we don't want to delete the model from the database, but we do
            # want to alert the user.
            for model in self._scan_for_missing_models():
                self._logger.warning(f"Missing model file: {model.name} at {model.path}")

    def stop(self, invoker: Optional[Invoker] = None) -> None:
        """Stop the installer thread; after this the object can be deleted and garbage collected."""
        if not self._running:
            raise Exception("Attempt to stop the install service before it was started")
        self._logger.debug("calling stop_event.set()")
        self._stop_event.set()
        self._clear_pending_jobs()
        self._download_cache.clear()
        assert self._install_thread is not None
        self._install_thread.join()
        self._running = False

    def _clear_pending_jobs(self) -> None:
        for job in self.list_jobs():
            if not job.in_terminal_state:
                self._logger.warning(f"Cancelling job {job.id}")
                self.cancel_job(job)
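        # Drain anything left in the queue, marking each item done so that a
        # later Queue.join() cannot block on unprocessed entries.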
        while True:
            try:
                job = self._install_queue.get(block=False)
                self._install_queue.task_done()
            except Empty:
                break

    def _put_in_queue(self, job: ModelInstallJob) -> None:
        if self._stop_event.is_set():
            self.cancel_job(job)
        else:
            self._install_queue.put(job)

    def register_path(
        self,
        model_path: Union[Path, str],
        config: Optional[Dict[str, Any]] = None,
    ) -> str:  # noqa D102
        model_path = Path(model_path)
        config = config or {}
        if not config.get("source"):
            config["source"] = model_path.resolve().as_posix()
        config["source_type"] = ModelSourceType.Path
        return self._register(model_path, config)

    def install_path(
        self,
        model_path: Union[Path, str],
        config: Optional[Dict[str, Any]] = None,
    ) -> str:  # noqa D102
        model_path = Path(model_path)
        config = config or {}

        info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)

        if preferred_name := config.get("name"):
            preferred_name = Path(preferred_name).with_suffix(model_path.suffix)

        dest_path = (
            self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
        )
        try:
            new_path = self._copy_model(model_path, dest_path)
        except FileExistsError as excp:
            raise DuplicateModelException(
                f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
            ) from excp

        return self._register(
            new_path,
            config,
            info,
        )

    def heuristic_import(
        self,
        source: str,
        config: Optional[Dict[str, Any]] = None,
        access_token: Optional[str] = None,
        inplace: Optional[bool] = False,
    ) -> ModelInstallJob:
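        # A Hugging Face repo id looks like "owner/name", optionally followed
        # by ":variant" and/or ":subfolder" (e.g. "owner/name:fp16:subfolder").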
        variants = "|".join(ModelRepoVariant.__members__.values())
        hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
        source_obj: Optional[StringLikeSource] = None

        if Path(source).exists():  # A local file or directory
            source_obj = LocalModelSource(path=Path(source), inplace=inplace)
        elif match := re.match(hf_repoid_re, source):
            source_obj = HFModelSource(
                repo_id=match.group(1),
                variant=match.group(2) if match.group(2) else None,  # pass None rather than ''
                subfolder=Path(match.group(3)) if match.group(3) else None,
                access_token=access_token,
            )
        elif re.match(r"^https?://[^/]+", source):
            # Pull the token from config if it exists and matches the URL
            _token = access_token
            if _token is None:
                for pair in self.app_config.remote_api_tokens or []:
                    if re.search(pair.url_regex, source):
                        _token = pair.token
                        break
            source_obj = URLModelSource(
                url=AnyHttpUrl(source),
                access_token=_token,
            )
        else:
            raise ValueError(f"Unsupported model source: '{source}'")
        return self.import_model(source_obj, config)

    def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob:  # noqa D102
        similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
        if similar_jobs:
            self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
            return similar_jobs[0]

        if isinstance(source, LocalModelSource):
            install_job = self._import_local_model(source, config)
            self._put_in_queue(install_job)  # synchronously install
        elif isinstance(source, HFModelSource):
            install_job = self._import_from_hf(source, config)
        elif isinstance(source, URLModelSource):
            install_job = self._import_from_url(source, config)
        else:
            raise ValueError(f"Unsupported model source: '{type(source)}'")

        self._install_jobs.append(install_job)
        return install_job

    def list_jobs(self) -> List[ModelInstallJob]:  # noqa D102
        return self._install_jobs

    def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:  # noqa D102
        return [x for x in self._install_jobs if x.source == source]

    def get_job_by_id(self, id: int) -> ModelInstallJob:  # noqa D102
        jobs = [x for x in self._install_jobs if x.id == id]
        if not jobs:
            raise ValueError(f"No job with id {id} known")
        assert len(jobs) == 1
        assert isinstance(jobs[0], ModelInstallJob)
        return jobs[0]

    def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
        """Block until the indicated job has reached a terminal state, or until the timeout limit is reached."""
        start = time.time()
        while not job.in_terminal_state:
            if self._install_completed_event.wait(timeout=5):  # in case we miss an event
                self._install_completed_event.clear()
            if timeout > 0 and time.time() - start > timeout:
                raise TimeoutError("Timeout exceeded")
        return job

    def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:  # noqa D102
        """Block until all installation jobs are done."""
        start = time.time()
        while len(self._download_cache) > 0:
            if self._downloads_changed_event.wait(timeout=0.25):  # in case we miss an event
                self._downloads_changed_event.clear()
            if timeout > 0 and time.time() - start > timeout:
                raise TimeoutError("Timeout exceeded")
        self._install_queue.join()
        return self._install_jobs

    def cancel_job(self, job: ModelInstallJob) -> None:
        """Cancel the indicated job."""
        job.cancel()
        with self._lock:
            self._cancel_download_parts(job)

    def prune_jobs(self) -> None:
        """Prune all completed and errored jobs."""
        unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
        self._install_jobs = unfinished_jobs

    def _migrate_yaml(self) -> None:
        db_models = self.record_store.all_models()

        legacy_models_yaml_path = (
            self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
        )

        # The old path may be relative to the root path
        if not legacy_models_yaml_path.exists():
            legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)

        if legacy_models_yaml_path.exists():
            with open(legacy_models_yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
                legacy_models_yaml = yaml.safe_load(file)

            yaml_metadata = legacy_models_yaml.pop("__metadata__")
            yaml_version = yaml_metadata.get("version")
            if yaml_version != "3.0.0":
                raise ValueError(
                    f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
                )
            self._logger.info(
                f"Starting one-time migration of {len(legacy_models_yaml.items())} models from {str(legacy_models_yaml_path)}. This may take a few minutes."
            )

            if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0:
                for model_key, stanza in legacy_models_yaml.items():
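                    # v3 model keys have the form "<base>/<type>/<name>".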
                    _, _, model_name = str(model_key).split("/")
                    model_path = Path(stanza["path"])
                    if not model_path.is_absolute():
                        model_path = self._app_config.models_path / model_path
                    model_path = model_path.resolve()

                    config: dict[str, Any] = {}
                    config["name"] = model_name
                    config["description"] = stanza.get("description")

                    legacy_config_path = stanza.get("config")
                    if legacy_config_path:
                        # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
                        legacy_config_path: Path = self._app_config.root_path / legacy_config_path
                        if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
                            legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
                        config["config_path"] = str(legacy_config_path)
                    try:
                        id = self.register_path(model_path=model_path, config=config)
                        self._logger.info(f"Migrated {model_name} with id {id}")
                    except Exception as e:
                        self._logger.warning(f"Model at {model_path} could not be migrated: {e}")

            # Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
            legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))

        # Unset the path - we are done with it either way
        self._app_config.legacy_models_yaml_path = None

    def unregister(self, key: str) -> None:  # noqa D102
        self.record_store.del_model(key)

    def delete(self, key: str) -> None:  # noqa D102
        """Unregister the model. Delete its files only if they are within our models directory."""
        model = self.record_store.get_model(key)
        model_path = self.app_config.models_path / model.path
        if model_path.is_relative_to(self.app_config.models_path):
            # If the model is in the Invoke-managed models dir, we delete it
            self.unconditionally_delete(key)
        else:
            # Else we only unregister it, leaving the file in place
            self.unregister(key)

    def unconditionally_delete(self, key: str) -> None:  # noqa D102
        model = self.record_store.get_model(key)
        model_path = self.app_config.models_path / model.path
        if model_path.is_file() or model_path.is_symlink():
            model_path.unlink()
        elif model_path.is_dir():
            rmtree(model_path)
        self.unregister(key)

    def download_and_cache(
        self,
        source: Union[str, AnyHttpUrl],
        access_token: Optional[str] = None,
        timeout: int = 0,
    ) -> Path:
        """Download the model file located at source to the models cache and return its Path."""
        model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
        model_path = self._app_config.convert_cache_path / model_hash
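        # The cache key is a deterministic hash of the source string, so
        # repeated requests for the same source resolve to the same directory.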

        # We expect the cache directory to contain one and only one downloaded file.
        # We don't know the file's name in advance, as it is set by the download
        # content-disposition header.
        if model_path.exists():
            contents = [x for x in model_path.iterdir() if x.is_file()]
            if len(contents) > 0:
                return contents[0]

        model_path.mkdir(parents=True, exist_ok=True)
        job = self._download_queue.download(
            source=AnyHttpUrl(str(source)),
            dest=model_path,
            access_token=access_token,
            on_progress=TqdmProgress().update,
        )
        self._download_queue.wait_for_job(job, timeout)
        if job.complete:
            assert job.download_path is not None
            return job.download_path
        else:
            raise Exception(job.error)

    # --------------------------------------------------------------------------------------------
    # Internal functions that manage the installer threads
    # --------------------------------------------------------------------------------------------
    def _start_installer_thread(self) -> None:
        self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
        self._install_thread.start()
        self._running = True

    def _install_next_item(self) -> None:
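        # Worker loop for the installer thread: poll the queue until the stop
        # event is set, handling one install job per iteration.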
        self._logger.debug(f"Installer thread {threading.get_ident()} starting")
        while True:
            if self._stop_event.is_set():
                break
            self._logger.debug(f"Installer thread {threading.get_ident()} polling")
            try:
                job = self._install_queue.get(timeout=1)
            except Empty:
                continue
            assert job.local_path is not None
            try:
                if job.cancelled:
                    self._signal_job_cancelled(job)
                elif job.errored:
                    self._signal_job_errored(job)
                elif job.waiting or job.downloads_done:
                    self._register_or_install(job)
            except Exception as e:
                # Expected errors include InvalidModelConfigException, DuplicateModelException, OSError, but we must
                # gracefully handle _any_ error here.
                self._set_error(job, e)
            finally:
                # if this is an install of a remote file, then clean up the temporary directory
                if job._install_tmpdir is not None:
                    rmtree(job._install_tmpdir)
                self._install_completed_event.set()
                self._install_queue.task_done()
        self._logger.info(f"Installer thread {threading.get_ident()} exiting")

    def _register_or_install(self, job: ModelInstallJob) -> None:
        # local jobs will be in waiting state, remote jobs will be in downloading state
        job.total_bytes = self._stat_size(job.local_path)
        job.bytes = job.total_bytes
        self._signal_job_running(job)
        job.config_in["source"] = str(job.source)
        job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
        # enter the metadata, if there is any
        if isinstance(job.source_metadata, (HuggingFaceMetadata)):
            job.config_in["source_api_response"] = job.source_metadata.api_response

        if job.inplace:
            key = self.register_path(job.local_path, job.config_in)
        else:
            key = self.install_path(job.local_path, job.config_in)
        job.config_out = self.record_store.get_model(key)
        self._signal_job_completed(job)

    def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
        if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
            job.set_error(
                InvalidModelConfigException(
                    f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
                )
            )
        else:
            job.set_error(excp)
        self._signal_job_errored(job)

    # --------------------------------------------------------------------------------------------
    # Internal functions that manage the models directory
    # --------------------------------------------------------------------------------------------
    def _remove_dangling_install_dirs(self) -> None:
        """Remove leftover tmpdirs from aborted installs."""
        path = self._app_config.models_path
        for tmpdir in path.glob(f"{TMPDIR_PREFIX}*"):
            self._logger.info(f"Removing dangling temporary directory {tmpdir}")
            rmtree(tmpdir)

    def _scan_for_missing_models(self) -> list[AnyModelConfig]:
        """Scan the models directory for missing models and return a list of them."""
        missing_models: list[AnyModelConfig] = []
        for model_config in self.record_store.all_models():
            if not (self.app_config.models_path / model_config.path).resolve().exists():
                missing_models.append(model_config)
        return missing_models

    def _register_orphaned_models(self) -> None:
        """Scan the Invoke-managed models directory for orphaned models and register them.

        This is typically only used during testing with a new DB or when using the memory DB, because those are the
        only situations in which we may have orphaned models in the models directory.
        """
        installed_model_paths = {
            (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
        }

        # The bool returned by this callback determines if the model is added to the list of models found by the search
        def on_model_found(model_path: Path) -> bool:
            resolved_path = model_path.resolve()
            # Already registered models should be in the list of found models, but not re-registered.
            if resolved_path in installed_model_paths:
                return True
            # Skip core models entirely - these aren't registered with the model manager.
            if str(resolved_path).startswith(str(self.app_config.models_path / "core")):
                return False
            try:
                model_id = self.register_path(model_path)
                self._logger.info(f"Registered {model_path.name} with id {model_id}")
            except DuplicateModelException:
                # In case a duplicate model sneaks by, we will ignore this error - we "found" the model
                pass
            return True

        self._logger.info(f"Scanning {self._app_config.models_path} for orphaned models")
        search = ModelSearch(on_model_found=on_model_found)
        found_models = search.search(self._app_config.models_path)
        self._logger.info(f"{len(found_models)} new models registered")

    def sync_model_path(self, key: str) -> AnyModelConfig:
        """
        Move model into the location indicated by its basetype, type and name.

        Call this after updating a model's attributes in order to move
        the model's path into the location indicated by its basetype, type and
        name. Applies only to models whose paths are within the root `models_dir`
        directory.

        May raise an UnknownModelException.
        """
        model = self.record_store.get_model(key)
        models_dir = self.app_config.models_path
        old_path = self.app_config.models_path / model.path

        if not old_path.is_relative_to(models_dir):
            # The model is not in the models directory - we don't need to move it.
            return model

        new_path = models_dir / model.base.value / model.type.value / old_path.name

        if old_path == new_path or (new_path.exists() and old_path == new_path.resolve()):
            return model

        self._logger.info(f"Moving {model.name} to {new_path}.")
        new_path = self._move_model(old_path, new_path)
        model.path = new_path.relative_to(models_dir).as_posix()
        self.record_store.update_model(key, ModelRecordChanges(path=model.path))
        return model

    def _copy_model(self, old_path: Path, new_path: Path) -> Path:
        if old_path == new_path:
            return old_path
        new_path.parent.mkdir(parents=True, exist_ok=True)
        if old_path.is_dir():
            copytree(old_path, new_path)
        else:
            copyfile(old_path, new_path)
        return new_path

    def _move_model(self, old_path: Path, new_path: Path) -> Path:
        if old_path == new_path:
            return old_path

        new_path.parent.mkdir(parents=True, exist_ok=True)

        # if the path already exists then we jigger the name to make it unique
        counter: int = 1
        while new_path.exists():
            path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
            if not path.exists():
                new_path = path
            counter += 1
        move(old_path, new_path)
        return new_path

    def _register(
        self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
    ) -> str:
        config = config or {}

        info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)

        model_path = model_path.resolve()

        # Models in the Invoke-managed models dir should use relative paths.
        if model_path.is_relative_to(self.app_config.models_path):
            model_path = model_path.relative_to(self.app_config.models_path)

        info.path = model_path.as_posix()

        if isinstance(info, CheckpointConfigBase):
            # Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
            # invoke-managed legacy config dir, we use a relative path.
            legacy_config_path = self.app_config.legacy_conf_path / info.config_path
            if legacy_config_path.is_relative_to(self.app_config.legacy_conf_path):
                legacy_config_path = legacy_config_path.relative_to(self.app_config.legacy_conf_path)
            info.config_path = legacy_config_path.as_posix()
        self.record_store.add_model(info)
        return info.key

    def _next_id(self) -> int:
        with self._lock:
            id = self._next_job_id
            self._next_job_id += 1
        return id

    @staticmethod
    def _guess_variant() -> Optional[ModelRepoVariant]:
        """Guess the best HuggingFace variant type to download."""
        precision = choose_precision(choose_torch_device())
        return ModelRepoVariant.FP16 if precision == "float16" else None

    def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
        return ModelInstallJob(
            id=self._next_id(),
            source=source,
            config_in=config or {},
            local_path=Path(source.path),
            inplace=source.inplace or False,
        )

    def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
        # Add user's cached access token to HuggingFace requests
        source.access_token = source.access_token or HfFolder.get_token()
        if not source.access_token:
            self._logger.info("No HuggingFace access token present; some models may not be downloadable.")

        metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
        assert isinstance(metadata, ModelMetadataWithFiles)
        remote_files = metadata.download_urls(
            variant=source.variant or self._guess_variant(),
            subfolder=source.subfolder,
            session=self._session,
        )

        return self._import_remote_model(
            source=source,
            config=config,
            remote_files=remote_files,
            metadata=metadata,
        )

    def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
        # URLs from HuggingFace will be handled specially
        metadata = None
        fetcher = None
        try:
            fetcher = self.get_fetcher_from_url(str(source.url))
        except ValueError:
            pass
        kwargs: dict[str, Any] = {"session": self._session}
        if fetcher is not None:
            metadata = fetcher(**kwargs).from_url(source.url)
        self._logger.debug(f"metadata={metadata}")
        if metadata and isinstance(metadata, ModelMetadataWithFiles):
            remote_files = metadata.download_urls(session=self._session)
        else:
            remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
        return self._import_remote_model(
            source=source,
            config=config,
            metadata=metadata,
            remote_files=remote_files,
        )

    def _import_remote_model(
        self,
        source: HFModelSource | URLModelSource,
        remote_files: List[RemoteModelFile],
        metadata: Optional[AnyModelRepoMetadata],
        config: Optional[Dict[str, Any]],
    ) -> ModelInstallJob:
        # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
        # Currently the tmpdir isn't automatically removed at exit because it is
        # being held in a daemon thread.
        if len(remote_files) == 0:
            raise ValueError(f"{source}: No downloadable files found")
        tmpdir = Path(
            mkdtemp(
                dir=self._app_config.models_path,
                prefix=TMPDIR_PREFIX,
            )
        )
        install_job = ModelInstallJob(
            id=self._next_id(),
            source=source,
            config_in=config or {},
            source_metadata=metadata,
            local_path=tmpdir,  # local path may change once the download has started due to content-disposition handling
            bytes=0,
            total_bytes=0,
        )
        # In the event that there is a subfolder specified in the source,
        # we need to remove it from the destination path in order to avoid
        # creating unwanted subfolders
        if isinstance(source, HFModelSource) and source.subfolder:
            root = Path(remote_files[0].path.parts[0])
            subfolder = root / source.subfolder
        else:
            root = Path(".")
            subfolder = Path(".")

        # we remember the path up to the top of the tmpdir so that it may be
        # removed safely at the end of the install process.
        install_job._install_tmpdir = tmpdir
        assert install_job.total_bytes is not None  # to avoid type checking complaints in the loop below

        files_string = "file" if len(remote_files) == 1 else "files"
        self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
        self._logger.debug(f"remote_files={remote_files}")
        for model_file in remote_files:
            url = model_file.url
            path = root / model_file.path.relative_to(subfolder)
            self._logger.debug(f"Downloading {url} => {path}")
            install_job.total_bytes += model_file.size
            assert hasattr(source, "access_token")
            dest = tmpdir / path.parent
            dest.mkdir(parents=True, exist_ok=True)
            download_job = DownloadJob(
                source=url,
                dest=dest,
                access_token=source.access_token,
            )
            self._download_cache[download_job.source] = install_job  # matches a download job to an install job
            install_job.download_parts.add(download_job)

            self._download_queue.submit_download_job(
                download_job,
                on_start=self._download_started_callback,
                on_progress=self._download_progress_callback,
                on_complete=self._download_complete_callback,
                on_error=self._download_error_callback,
                on_cancelled=self._download_cancelled_callback,
            )
        return install_job

    def _stat_size(self, path: Path) -> int:
        size = 0
        if path.is_file():
            size = path.stat().st_size
        elif path.is_dir():
            for root, _, files in os.walk(path):
                size += sum(self._stat_size(Path(root, x)) for x in files)
        return size

    # ------------------------------------------------------------------
    # Callbacks are executed by the download queue in a separate thread
    # ------------------------------------------------------------------
    def _download_started_callback(self, download_job: DownloadJob) -> None:
        self._logger.info(f"Model download started: {download_job.source}")
        with self._lock:
            install_job = self._download_cache[download_job.source]
            install_job.status = InstallStatus.DOWNLOADING

            assert download_job.download_path
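            # The first component of the download path relative to the tmpdir
            # is the model's top-level file or directory; adopt it as the
            # install job's local path.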
            if install_job.local_path == install_job._install_tmpdir:
                partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
                dest_name = partial_path.parts[0]
                install_job.local_path = install_job._install_tmpdir / dest_name

            # Update the total bytes count for remote sources.
            if not install_job.total_bytes:
                install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)

    def _download_progress_callback(self, download_job: DownloadJob) -> None:
        with self._lock:
            install_job = self._download_cache[download_job.source]
            if install_job.cancelled:  # This catches the case in which the caller directly calls job.cancel()
                self._cancel_download_parts(install_job)
            else:
                # update sizes
                install_job.bytes = sum(x.bytes for x in install_job.download_parts)
                self._signal_job_downloading(install_job)

    def _download_complete_callback(self, download_job: DownloadJob) -> None:
        self._logger.info(f"Model download complete: {download_job.source}")
        with self._lock:
            install_job = self._download_cache[download_job.source]

            # are there any more active jobs left in this task?
            if install_job.downloading and all(x.complete for x in install_job.download_parts):
                self._signal_job_downloads_done(install_job)
                self._put_in_queue(install_job)

            # Let other threads know that the number of downloads has changed
            self._download_cache.pop(download_job.source, None)
            self._downloads_changed_event.set()

    def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
        with self._lock:
            install_job = self._download_cache.pop(download_job.source, None)
            assert install_job is not None
            assert excp is not None
            install_job.set_error(excp)
            self._logger.error(
                f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
            )
            self._cancel_download_parts(install_job)

            # Let other threads know that the number of downloads has changed
            self._downloads_changed_event.set()

    def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
        with self._lock:
            install_job = self._download_cache.pop(download_job.source, None)
            if not install_job:
                return
            self._downloads_changed_event.set()
            self._logger.warning(f"Model download canceled: {download_job.source}")
            # if install job has already registered an error, then do not replace its status with cancelled
            if not install_job.errored:
                install_job.cancel()
            self._cancel_download_parts(install_job)

            # Let other threads know that the number of downloads has changed
            self._downloads_changed_event.set()

    def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
        # on multipart downloads, this method will get called repeatedly from the download callbacks
        # do not lock here because it gets called within a locked context
        for s in install_job.download_parts:
            self._download_queue.cancel_job(s)

        if all(x.in_terminal_state for x in install_job.download_parts):
            # When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
            self._put_in_queue(install_job)

    # ------------------------------------------------------------------------------------------------
    # Internal methods that put events on the event bus
    # ------------------------------------------------------------------------------------------------
    def _signal_job_running(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.RUNNING
        self._logger.info(f"Model install started: {job.source}")
        if self._event_bus:
            self._event_bus.emit_model_install_running(str(job.source))

    def _signal_job_downloading(self, job: ModelInstallJob) -> None:
        if self._event_bus:
            parts: List[Dict[str, str | int]] = [
                {
                    "url": str(x.source),
                    "local_path": str(x.download_path),
                    "bytes": x.bytes,
                    "total_bytes": x.total_bytes,
                }
                for x in job.download_parts
            ]
            assert job.bytes is not None
            assert job.total_bytes is not None
            self._event_bus.emit_model_install_downloading(
                str(job.source),
                local_path=job.local_path.as_posix(),
                parts=parts,
                bytes=job.bytes,
                total_bytes=job.total_bytes,
                id=job.id,
            )

    def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.DOWNLOADS_DONE
        self._logger.info(f"Model download complete: {job.source}")
        if self._event_bus:
            self._event_bus.emit_model_install_downloads_done(str(job.source))

    def _signal_job_completed(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.COMPLETED
        assert job.config_out
        self._logger.info(f"Model install complete: {job.source}")
        self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
        if self._event_bus:
            assert job.local_path is not None
            assert job.config_out is not None
            key = job.config_out.key
            self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)

    def _signal_job_errored(self, job: ModelInstallJob) -> None:
        self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
        if self._event_bus:
            error_type = job.error_type
            error = job.error
            assert error_type is not None
            assert error is not None
            self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)

    def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
        self._logger.info(f"Model install canceled: {job.source}")
        if self._event_bus:
            self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)

    @staticmethod
    def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
        if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
            return HuggingFaceMetadataFetch
        raise ValueError(f"Unsupported model source: '{url}'")