# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
"""Baseclass definitions for the model installer."""

import re
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union

from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore


class InstallStatus(str, Enum):
    """State of an install job running in the background."""

    WAITING = "waiting"  # waiting to be dequeued
    DOWNLOADING = "downloading"  # downloading of model files in process
    RUNNING = "running"  # being processed
    COMPLETED = "completed"  # finished running
    ERROR = "error"  # terminated with an error message
    CANCELLED = "cancelled"  # terminated by user request


class ModelInstallPart(BaseModel):
    url: AnyHttpUrl
    path: Path
    bytes: int = 0
    total_bytes: int = 0


class UnknownInstallJobException(Exception):
    """Raised when the status of an unknown job is requested."""


class StringLikeSource(BaseModel):
    """
    Base class for model sources; implements functions that let the source be sorted and indexed.

    These shenanigans let this stuff work:

      source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
      mydict = {source1: 'model 1'}
      assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
      assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'

      source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
      assert source1 == source2
      assert source1 == 'C:/users/mort/foo.safetensors'
    """

    def __hash__(self) -> int:
        """Return hash of the stringified source, for indexing."""
        return hash(str(self))

    def __lt__(self, other: object) -> bool:
        """Return comparison of the stringified version, for sorting."""
        return str(self) < str(other)

    def __eq__(self, other: object) -> bool:
        """Return equality on the stringified version."""
        if isinstance(other, Path):
            return str(self) == other.as_posix()
        else:
            return str(self) == str(other)


class LocalModelSource(StringLikeSource):
    """A local file or directory path."""

    path: str | Path
    inplace: Optional[bool] = False
    type: Literal["local"] = "local"

    # these methods allow the source to be used in a string-like way,
    # for example as an index into a dict
    def __str__(self) -> str:
        """Return string version of path when string rep needed."""
        return Path(self.path).as_posix()


class CivitaiModelSource(StringLikeSource):
    """A Civitai version id, with optional variant and access token."""

    version_id: int
    variant: Optional[ModelRepoVariant] = None
    access_token: Optional[str] = None
    type: Literal["civitai"] = "civitai"

    def __str__(self) -> str:
        """Return string version of the version id when string rep needed."""
        base: str = str(self.version_id)
        base += f" ({self.variant})" if self.variant else ""
        return base


class HFModelSource(StringLikeSource):
    """
    A HuggingFace repo_id with optional variant, sub-folder and access token.

    Note that the variant option, if not provided to the constructor, will default to fp16, which is
    what people (almost) always want.
    """

    repo_id: str
    variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
    subfolder: Optional[Path] = None
    access_token: Optional[str] = None
    type: Literal["hf"] = "hf"

    @field_validator("repo_id")
    @classmethod
    def proper_repo_id(cls, v: str) -> str:  # noqa D102
        if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
            raise ValueError(f"{v}: invalid repo_id format")
        return v

    def __str__(self) -> str:
        """Return string version of repo_id when string rep needed."""
        base: str = self.repo_id
        base += f":{self.subfolder}" if self.subfolder else ""
        base += f" ({self.variant})" if self.variant else ""
        return base


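# A minimal sketch of how the fp16 default plays out (editor's illustration; the repo id
# below is an example only, not a value used anywhere in this module):
#
#   src = HFModelSource(repo_id="stabilityai/sdxl-turbo")                 # variant defaults to fp16
#   full = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=None)  # request full-precision files

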
class URLModelSource(StringLikeSource):
    """A generic URL pointing to a checkpoint file."""

    url: AnyHttpUrl
    access_token: Optional[str] = None
    type: Literal["url"] = "url"

    def __str__(self) -> str:
        """Return string version of the url when string rep needed."""
        return str(self.url)


ModelSource = Annotated[
    Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
]
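
# A minimal sketch of how the discriminated union resolves (editor's illustration,
# assuming Pydantic v2's TypeAdapter API): the "type" field selects the concrete class.
#
#   from pydantic import TypeAdapter
#
#   adapter = TypeAdapter(ModelSource)
#   src = adapter.validate_python({"type": "hf", "repo_id": "stabilityai/sdxl-turbo"})
#   assert isinstance(src, HFModelSource)

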
class ModelInstallJob(BaseModel):
    """Object that tracks the current status of an install request."""

    id: int = Field(description="Unique ID for this job")
    status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
    config_in: Dict[str, Any] = Field(
        default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
    )
    config_out: Optional[AnyModelConfig] = Field(
        default=None, description="After successful installation, this will hold the configuration object."
    )
    inplace: bool = Field(
        default=False, description="Leave model in its current location; otherwise install under models directory"
    )
    source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
    local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
    bytes: Optional[int] = Field(
        default=None, description="For a remote model, the number of bytes downloaded so far (may not be available)"
    )
    total_bytes: int = Field(default=0, description="Total size of the model to be installed")
    source_metadata: Optional[AnyModelRepoMetadata] = Field(
        default=None, description="Metadata provided by the model source"
    )
    download_parts: Set[DownloadJob] = Field(
        default_factory=set, description="Download jobs contributing to this install"
    )
    # internal flags and transitory settings
    _install_tmpdir: Optional[Path] = PrivateAttr(default=None)
    _exception: Optional[Exception] = PrivateAttr(default=None)

    def set_error(self, e: Exception) -> None:
        """Record the error and traceback from an exception."""
        self._exception = e
        self.status = InstallStatus.ERROR

    def cancel(self) -> None:
        """Call to cancel the job."""
        self.status = InstallStatus.CANCELLED

    @property
    def error_type(self) -> Optional[str]:
        """Class name of the exception that led to status==ERROR."""
        return self._exception.__class__.__name__ if self._exception else None

    @property
    def error(self) -> Optional[str]:
        """Error traceback."""
        return "".join(traceback.format_exception(self._exception)) if self._exception else None

    @property
    def cancelled(self) -> bool:
        """Return true if the job was cancelled."""
        return self.status == InstallStatus.CANCELLED

    @property
    def errored(self) -> bool:
        """Return true if job has errored."""
        return self.status == InstallStatus.ERROR

    @property
    def waiting(self) -> bool:
        """Return true if job is waiting to run."""
        return self.status == InstallStatus.WAITING

    @property
    def downloading(self) -> bool:
        """Return true if job is downloading."""
        return self.status == InstallStatus.DOWNLOADING

    @property
    def running(self) -> bool:
        """Return true if job is running."""
        return self.status == InstallStatus.RUNNING

    @property
    def complete(self) -> bool:
        """Return true if job completed without errors."""
        return self.status == InstallStatus.COMPLETED

    @property
    def in_terminal_state(self) -> bool:
        """Return true if job is in a terminal state."""
        return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]


class ModelInstallServiceBase(ABC):
    """Abstract base class for InvokeAI model installation."""

    @abstractmethod
    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        record_store: ModelRecordServiceBase,
        download_queue: DownloadQueueServiceBase,
        metadata_store: ModelMetadataStore,
        event_bus: Optional["EventServiceBase"] = None,
    ):
        """
        Create ModelInstallService object.

        :param app_config: Systemwide InvokeAIAppConfig.
        :param record_store: Systemwide ModelRecordServiceBase for model configuration records.
        :param download_queue: Systemwide DownloadQueueServiceBase used for remote downloads.
        :param metadata_store: Systemwide ModelMetadataStore.
        :param event_bus: InvokeAI event bus for reporting events to.
        """

    # The invoker is made optional here because the installer does not need it,
    # and requiring it would make the installer harder to use outside the web app.
    @abstractmethod
    def start(self, invoker: Optional[Invoker] = None) -> None:
        """Start the installer service."""

    @abstractmethod
    def stop(self, invoker: Optional[Invoker] = None) -> None:
        """Stop the model install service. After this the object can be safely deleted."""

    @property
    @abstractmethod
    def app_config(self) -> InvokeAIAppConfig:
        """Return the InvokeAIAppConfig object associated with the installer."""

    @property
    @abstractmethod
    def record_store(self) -> ModelRecordServiceBase:
        """Return the ModelRecordService object associated with the installer."""

    @property
    @abstractmethod
    def event_bus(self) -> Optional[EventServiceBase]:
        """Return the event service base object associated with the installer."""

    @abstractmethod
    def register_path(
        self,
        model_path: Union[Path, str],
        config: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Probe and register the model at model_path.

        This keeps the model in its current location.

        :param model_path: Filesystem Path to the model.
        :param config: Dict of attributes that will override autoassigned values.
        :returns id: The string ID of the registered model.
        """

    @abstractmethod
    def unregister(self, key: str) -> None:
        """Remove model with indicated key from the database."""

    @abstractmethod
    def delete(self, key: str) -> None:
        """Remove model with indicated key from the database. Delete its files only if they are within our models directory."""

    @abstractmethod
    def unconditionally_delete(self, key: str) -> None:
        """Remove model with indicated key from the database and unconditionally delete weight files from disk."""

    @abstractmethod
    def install_path(
        self,
        model_path: Union[Path, str],
        config: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Probe, register and install the model in the models directory.

        This moves the model from its current location into
        the models directory handled by InvokeAI.

        :param model_path: Filesystem Path to the model.
        :param config: Dict of attributes that will override autoassigned values.
        :returns id: The string ID of the registered model.
        """

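    # Illustrative contrast (editor's sketch; `installer`, the path and the name are
    # hypothetical): register_path() records the model where it already lives, while
    # install_path() moves it under the models directory before registering it.
    #
    #   key1 = installer.register_path("/data/checkpoints/my_model.safetensors")
    #   key2 = installer.install_path(
    #       "/data/checkpoints/my_model.safetensors", config={"name": "My Model"}
    #   )
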
    @abstractmethod
    def import_model(
        self,
        source: ModelSource,
        config: Optional[Dict[str, Any]] = None,
    ) -> ModelInstallJob:
        """Install the indicated model.

        :param source: ModelSource object

        :param config: Optional dict. Any fields in this dict
         will override corresponding autoassigned probe fields in the
         model's config record. Use it to override
         `name`, `description`, `base_type`, `model_type`, `format`,
         `prediction_type`, `image_size`, and/or `ztsnr_training`.

        This will download the model located at `source`,
        probe it, and install it into the models directory.
        This call is executed asynchronously in a separate
        thread and will issue the following events on the event bus:

             - model_install_started
             - model_install_error
             - model_install_completed

        The `inplace` flag does not affect the behavior of downloaded
        models, which are always moved into the `models` directory.

        The call returns a ModelInstallJob object which can be
        polled to learn the current status and/or error message.

        Variants recognized by HuggingFace currently are:
        1. onnx
        2. openvino
        3. fp16
        4. None (usually returns fp32 model)
        """

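    # A hedged usage sketch (only the abstract interface is defined in this module;
    # `installer` stands for any concrete implementation, and the repo id is illustrative):
    #
    #   job = installer.import_model(HFModelSource(repo_id="stabilityai/sdxl-turbo"))
    #   installer.wait_for_installs()
    #   if job.complete:
    #       print("installed as", job.config_out)
    #   elif job.errored:
    #       print(job.error_type, job.error)
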
    @abstractmethod
    def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
        """Return the ModelInstallJob(s) corresponding to the provided source."""

    @abstractmethod
    def get_job_by_id(self, id: int) -> ModelInstallJob:
        """Return the ModelInstallJob corresponding to the provided id. Raises ValueError if no job has that ID."""

    @abstractmethod
    def list_jobs(self) -> List[ModelInstallJob]:  # noqa D102
        """List active and complete install jobs."""

    @abstractmethod
    def prune_jobs(self) -> None:
        """Prune all completed and errored jobs."""

    @abstractmethod
    def cancel_job(self, job: ModelInstallJob) -> None:
        """Cancel the indicated job."""

    @abstractmethod
    def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:
        """
        Wait for all pending installs to complete.

        This will block until all pending installs have
        completed, been cancelled, or errored out.

        :param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if
         installs do not complete within the indicated time.
        """

    @abstractmethod
    def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
        """
        Recursively scan directory for new models and register or install them.

        :param scan_dir: Path to the directory to scan.
        :param install: Install if True, otherwise register in place.
        :returns list of IDs: Returns list of IDs of models registered/installed
        """

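    # For example (editor's sketch; `installer` and the directory are hypothetical),
    # registering newly found models in place rather than moving them:
    #
    #   new_keys = installer.scan_directory(Path("/data/autoimport"), install=False)
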
    @abstractmethod
    def sync_to_config(self) -> None:
        """Synchronize models on disk to those in the model record database."""