mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 07:28:06 -05:00
Compare commits
28 Commits
v4.2.9.dev
...
refactor/m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ffe672bc1 | ||
|
|
ed2d9ae0d9 | ||
|
|
09e7d35b55 | ||
|
|
9758082dc5 | ||
|
|
5f4ce0b118 | ||
|
|
8ac4b9b32c | ||
|
|
ec77599e79 | ||
|
|
2c1b8c0bc2 | ||
|
|
d4525e1282 | ||
|
|
b0d67ea2cc | ||
|
|
bd802d1e7a | ||
|
|
433eb73d8e | ||
|
|
b71f53ba86 | ||
|
|
68064c133a | ||
|
|
411ec1ed64 | ||
|
|
40a81c358d | ||
|
|
1d724bca4a | ||
|
|
a6508d1391 | ||
|
|
1eeca48529 | ||
|
|
79d028ecbd | ||
|
|
531d2c8fd7 | ||
|
|
37675ee4f5 | ||
|
|
26f721d0ec | ||
|
|
420f6050a6 | ||
|
|
9804cb0e67 | ||
|
|
4c5aedbcba | ||
|
|
a380d1f3b2 | ||
|
|
6b8a6e12bc |
@@ -28,7 +28,7 @@ model. These are the:
|
||||
Hugging Face, as well as discriminating among model versions in
|
||||
Civitai, but can be used for arbitrary content.
|
||||
|
||||
* _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
|
||||
* _ModelLoadServiceBase_
|
||||
Responsible for loading a model from disk
|
||||
into RAM and VRAM and getting it ready for inference.
|
||||
|
||||
@@ -41,10 +41,10 @@ The four main services can be found in
|
||||
* `invokeai/app/services/model_records/`
|
||||
* `invokeai/app/services/model_install/`
|
||||
* `invokeai/app/services/downloads/`
|
||||
* `invokeai/app/services/model_loader/` (**under development**)
|
||||
* `invokeai/app/services/model_load/`
|
||||
|
||||
Code related to the FastAPI web API can be found in
|
||||
`invokeai/app/api/routers/model_records.py`.
|
||||
`invokeai/app/api/routers/model_manager_v2.py`.
|
||||
|
||||
***
|
||||
|
||||
@@ -84,10 +84,10 @@ diffusers model. When this happens, `original_hash` is unchanged, but
|
||||
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
|
||||
are defined in `invokeai.backend.model_manager.config`. They are also
|
||||
imported by, and can be reexported from,
|
||||
`invokeai.app.services.model_record_service`:
|
||||
`invokeai.app.services.model_manager.model_records`:
|
||||
|
||||
```
|
||||
from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType
|
||||
from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType
|
||||
```
|
||||
|
||||
The `path` field can be absolute or relative. If relative, it is taken
|
||||
@@ -123,7 +123,7 @@ taken to be the `models_dir` directory.
|
||||
|
||||
`variant` is an enumerated string class with values `normal`,
|
||||
`inpaint` and `depth`. If needed, it can be imported if needed from
|
||||
either `invokeai.app.services.model_record_service` or
|
||||
either `invokeai.app.services.model_records` or
|
||||
`invokeai.backend.model_manager.config`.
|
||||
|
||||
### ONNXSD2Config
|
||||
@@ -134,7 +134,7 @@ either `invokeai.app.services.model_record_service` or
|
||||
| `upcast_attention` | bool | Model requires its attention module to be upcast |
|
||||
|
||||
The `SchedulerPredictionType` enum can be imported from either
|
||||
`invokeai.app.services.model_record_service` or
|
||||
`invokeai.app.services.model_records` or
|
||||
`invokeai.backend.model_manager.config`.
|
||||
|
||||
### Other config classes
|
||||
@@ -157,15 +157,6 @@ indicates that the model is compatible with any of the base
|
||||
models. This works OK for some models, such as the IP Adapter image
|
||||
encoders, but is an all-or-nothing proposition.
|
||||
|
||||
Another issue is that the config class hierarchy is paralleled to some
|
||||
extent by a `ModelBase` class hierarchy defined in
|
||||
`invokeai.backend.model_manager.models.base` and its subclasses. These
|
||||
are classes representing the models after they are loaded into RAM and
|
||||
include runtime information such as load status and bytes used. Some
|
||||
of the fields, including `name`, `model_type` and `base_model`, are
|
||||
shared between `ModelConfigBase` and `ModelBase`, and this is a
|
||||
potential source of confusion.
|
||||
|
||||
## Reading and Writing Model Configuration Records
|
||||
|
||||
The `ModelRecordService` provides the ability to retrieve model
|
||||
@@ -177,11 +168,11 @@ initialization and can be retrieved within an invocation from the
|
||||
`InvocationContext` object:
|
||||
|
||||
```
|
||||
store = context.services.model_record_store
|
||||
store = context.services.model_manager.store
|
||||
```
|
||||
|
||||
or from elsewhere in the code by accessing
|
||||
`ApiDependencies.invoker.services.model_record_store`.
|
||||
`ApiDependencies.invoker.services.model_manager.store`.
|
||||
|
||||
### Creating a `ModelRecordService`
|
||||
|
||||
@@ -190,7 +181,7 @@ you can directly create either a `ModelRecordServiceSQL` or a
|
||||
`ModelRecordServiceFile` object:
|
||||
|
||||
```
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceFile
|
||||
|
||||
store = ModelRecordServiceSQL.from_connection(connection, lock)
|
||||
store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db')
|
||||
@@ -252,7 +243,7 @@ So a typical startup pattern would be:
|
||||
```
|
||||
import sqlite3
|
||||
from invokeai.app.services.thread import lock
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
@@ -260,19 +251,6 @@ db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False)
|
||||
store = ModelRecordServiceBase.open(config, db_conn, lock)
|
||||
```
|
||||
|
||||
_A note on simultaneous access to `invokeai.db`_: The current InvokeAI
|
||||
service architecture for the image and graph databases is careful to
|
||||
use a shared sqlite3 connection and a thread lock to ensure that two
|
||||
threads don't attempt to access the database simultaneously. However,
|
||||
the default `sqlite3` library used by Python reports using
|
||||
**Serialized** mode, which allows multiple threads to access the
|
||||
database simultaneously using multiple database connections (see
|
||||
https://www.sqlite.org/threadsafe.html and
|
||||
https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore
|
||||
it should be safe to allow the record service to open its own SQLite
|
||||
database connection. Opening a model record service should then be as
|
||||
simple as `ModelRecordServiceBase.open(config)`.
|
||||
|
||||
### Fetching a Model's Configuration from `ModelRecordServiceBase`
|
||||
|
||||
Configurations can be retrieved in several ways.
|
||||
@@ -468,6 +446,44 @@ required parameters:
|
||||
|
||||
Once initialized, the installer will provide the following methods:
|
||||
|
||||
#### install_job = installer.heuristic_import(source, [config], [access_token])
|
||||
|
||||
This is a simplified interface to the installer which takes a source
|
||||
string, an optional model configuration dictionary and an optional
|
||||
access token.
|
||||
|
||||
The `source` is a string that can be any of these forms
|
||||
|
||||
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
|
||||
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
|
||||
3. A HuggingFace repo_id with any of the following formats:
|
||||
- `model/name` -- entire model
|
||||
- `model/name:fp32` -- entire model, using the fp32 variant
|
||||
- `model/name:fp16:vae` -- vae submodel, using the fp16 variant
|
||||
- `model/name::vae` -- vae submodel, using default precision
|
||||
- `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
|
||||
- `model/name::path/to/model.safetensors` -- an individual model file, default variant
|
||||
|
||||
Note that by specifying a relative path to the top of the HuggingFace
|
||||
repo, you can download and install arbitrary models files.
|
||||
|
||||
The variant, if not provided, will be automatically filled in with
|
||||
`fp32` if the user has requested full precision, and `fp16`
|
||||
otherwise. If a variant that does not exist is requested, then the
|
||||
method will install whatever HuggingFace returns as its default
|
||||
revision.
|
||||
|
||||
`config` is an optional dict of values that will override the
|
||||
autoprobed values for model type, base, scheduler prediction type, and
|
||||
so forth. See [Model configuration and
|
||||
probing](#Model-configuration-and-probing) for details.
|
||||
|
||||
`access_token` is an optional access token for accessing resources
|
||||
that need authentication.
|
||||
|
||||
The method will return a `ModelInstallJob`. This object is discussed
|
||||
at length in the following section.
|
||||
|
||||
#### install_job = installer.import_model()
|
||||
|
||||
The `import_model()` method is the core of the installer. The
|
||||
@@ -486,9 +502,10 @@ source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local dif
|
||||
source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id
|
||||
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id
|
||||
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model
|
||||
source6 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='OrangeMix/OrangeMix1.ckpt') # path to an individual model file
|
||||
|
||||
source6 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
|
||||
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
|
||||
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
|
||||
source8 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
|
||||
|
||||
for source in [source1, source2, source3, source4, source5, source6, source7]:
|
||||
install_job = installer.install_model(source)
|
||||
@@ -544,7 +561,6 @@ can be passed to `import_model()`.
|
||||
attributes returned by the model prober. See the section below for
|
||||
details.
|
||||
|
||||
|
||||
#### LocalModelSource
|
||||
|
||||
This is used for a model that is located on a locally-accessible Posix
|
||||
@@ -737,7 +753,7 @@ and `cancelled`, as well as `in_terminal_state`. The last will return
|
||||
True if the job is in the complete, errored or cancelled states.
|
||||
|
||||
|
||||
#### Model confguration and probing
|
||||
#### Model configuration and probing
|
||||
|
||||
The install service uses the `invokeai.backend.model_manager.probe`
|
||||
module during import to determine the model's type, base type, and
|
||||
@@ -776,6 +792,14 @@ returns a list of completed jobs. The optional `timeout` argument will
|
||||
return from the call if jobs aren't completed in the specified
|
||||
time. An argument of 0 (the default) will block indefinitely.
|
||||
|
||||
#### jobs = installer.wait_for_job(job, [timeout])
|
||||
|
||||
Like `wait_for_installs()`, but block until a specific job has
|
||||
completed or errored, and then return the job. The optional `timeout`
|
||||
argument will return from the call if the job doesn't complete in the
|
||||
specified time. An argument of 0 (the default) will block
|
||||
indefinitely.
|
||||
|
||||
#### jobs = installer.list_jobs()
|
||||
|
||||
Return a list of all active and complete `ModelInstallJobs`.
|
||||
@@ -838,6 +862,31 @@ This method is similar to `unregister()`, but also unconditionally
|
||||
deletes the corresponding model weights file(s), regardless of whether
|
||||
they are inside or outside the InvokeAI models hierarchy.
|
||||
|
||||
|
||||
#### path = installer.download_and_cache(remote_source, [access_token], [timeout])
|
||||
|
||||
This utility routine will download the model file located at source,
|
||||
cache it, and return the path to the cached file. It does not attempt
|
||||
to determine the model type, probe its configuration values, or
|
||||
register it with the models database.
|
||||
|
||||
You may provide an access token if the remote source requires
|
||||
authorization. The call will block indefinitely until the file is
|
||||
completely downloaded, cancelled or raises an error of some sort. If
|
||||
you provide a timeout (in seconds), the call will raise a
|
||||
`TimeoutError` exception if the download hasn't completed in the
|
||||
specified period.
|
||||
|
||||
You may use this mechanism to request any type of file, not just a
|
||||
model. The file will be stored in a subdirectory of
|
||||
`INVOKEAI_ROOT/models/.cache`. If the requested file is found in the
|
||||
cache, its path will be returned without redownloading it.
|
||||
|
||||
Be aware that the models cache is cleared of infrequently-used files
|
||||
and directories at regular intervals when the size of the cache
|
||||
exceeds the value specified in Invoke's `convert_cache` configuration
|
||||
variable.
|
||||
|
||||
#### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
|
||||
|
||||
This method will recursively scan the directory indicated in
|
||||
@@ -1128,7 +1177,7 @@ job = queue.create_download_job(
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
start=True,
|
||||
)
|
||||
```
|
||||
```
|
||||
|
||||
The `filename` argument forces the downloader to use the specified
|
||||
name for the file rather than the name provided by the remote source,
|
||||
@@ -1171,6 +1220,13 @@ queue or was not created by this queue.
|
||||
This method will block until all the active jobs in the queue have
|
||||
reached a terminal state (completed, errored or cancelled).
|
||||
|
||||
#### queue.wait_for_job(job, [timeout])
|
||||
|
||||
This method will block until the indicated job has reached a terminal
|
||||
state (completed, errored or cancelled). If the optional timeout is
|
||||
provided, the call will block for at most timeout seconds, and raise a
|
||||
TimeoutError otherwise.
|
||||
|
||||
#### jobs = queue.list_jobs()
|
||||
|
||||
This will return a list of all jobs, including ones that have not yet
|
||||
@@ -1449,9 +1505,9 @@ set of keys to the corresponding model config objects.
|
||||
Find all model metadata records that have the given author and return
|
||||
a set of keys to the corresponding model config objects.
|
||||
|
||||
# The remainder of this documentation is provisional, pending implementation of the Load service
|
||||
***
|
||||
|
||||
## Let's get loaded, the lowdown on ModelLoadService
|
||||
## The Lowdown on the ModelLoadService
|
||||
|
||||
The `ModelLoadService` is responsible for loading a named model into
|
||||
memory so that it can be used for inference. Despite the fact that it
|
||||
@@ -1465,7 +1521,7 @@ create alternative instances if you wish.
|
||||
### Creating a ModelLoadService object
|
||||
|
||||
The class is defined in
|
||||
`invokeai.app.services.model_loader_service`. It is initialized with
|
||||
`invokeai.app.services.model_load`. It is initialized with
|
||||
an InvokeAIAppConfig object, from which it gets configuration
|
||||
information such as the user's desired GPU and precision, and with a
|
||||
previously-created `ModelRecordServiceBase` object, from which it
|
||||
@@ -1475,26 +1531,29 @@ Here is a typical initialization pattern:
|
||||
|
||||
```
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_loader_service import ModelLoadService
|
||||
from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegistry
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
store = ModelRecordServiceBase.open(config)
|
||||
loader = ModelLoadService(config, store)
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||
)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||
)
|
||||
loader = ModelLoadService(
|
||||
app_config=config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
registry=ModelLoaderRegistry
|
||||
)
|
||||
```
|
||||
|
||||
Note that we are relying on the contents of the application
|
||||
configuration to choose the implementation of
|
||||
`ModelRecordServiceBase`.
|
||||
### load_model(model_config, [submodel_type], [context]) -> LoadedModel
|
||||
|
||||
### get_model(key, [submodel_type], [context]) -> ModelInfo:
|
||||
|
||||
*** TO DO: change to get_model(key, context=None, **kwargs)
|
||||
|
||||
The `get_model()` method, like its similarly-named cousin in
|
||||
`ModelRecordService`, receives the unique key that identifies the
|
||||
The `load_model()` method takes an `AnyModelConfig` returned by
|
||||
`ModelRecordService.get_model()` and returns the corresponding loaded
|
||||
model. It loads the model into memory, gets the model ready for use,
|
||||
and returns a `ModelInfo` object.
|
||||
and returns a `LoadedModel` object.
|
||||
|
||||
The optional second argument, `subtype` is a `SubModelType` string
|
||||
enum, such as "vae". It is mandatory when used with a main model, and
|
||||
@@ -1504,46 +1563,45 @@ The optional third argument, `context` can be provided by
|
||||
an invocation to trigger model load event reporting. See below for
|
||||
details.
|
||||
|
||||
The returned `ModelInfo` object shares some fields in common with
|
||||
`ModelConfigBase`, but is otherwise a completely different beast:
|
||||
The returned `LoadedModel` object contains a copy of the configuration
|
||||
record returned by the model record `get_model()` method, as well as
|
||||
the in-memory loaded model:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|
||||
| **Attribute Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `key` | str | The model key derived from the ModelRecordService database |
|
||||
| `name` | str | Name of this model |
|
||||
| `base_model` | BaseModelType | Base model for this model |
|
||||
| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)|
|
||||
| `location` | Path or str | Location of the model on the filesystem |
|
||||
| `precision` | torch.dtype | The torch.precision to use for inference |
|
||||
| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use |
|
||||
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
|
||||
| `model` | AnyModel | The instantiated model (details below) |
|
||||
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
|
||||
|
||||
The types for `ModelInfo` and `SubModelType` can be imported from
|
||||
`invokeai.app.services.model_loader_service`.
|
||||
Because the loader can return multiple model types, it is typed to
|
||||
return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
|
||||
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
|
||||
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
|
||||
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
|
||||
models. The others are obvious.
|
||||
|
||||
To use the model, you use the `ModelInfo` as a context manager using
|
||||
the following pattern:
|
||||
|
||||
`LoadedModel` acts as a context manager. The context loads the model
|
||||
into the execution device (e.g. VRAM on CUDA systems), locks the model
|
||||
in the execution device for the duration of the context, and returns
|
||||
the model. Use it like this:
|
||||
|
||||
```
|
||||
model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
with model_info as vae:
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
The `vae` model will stay locked in the GPU during the period of time
|
||||
it is in the context manager's scope.
|
||||
`get_model_by_key()` may raise any of the following exceptions:
|
||||
|
||||
`get_model()` may raise any of the following exceptions:
|
||||
|
||||
- `UnknownModelException` -- key not in database
|
||||
- `ModelNotFoundException` -- key in database but model not found at path
|
||||
- `InvalidModelException` -- the model is guilty of a variety of sins
|
||||
- `UnknownModelException` -- key not in database
|
||||
- `ModelNotFoundException` -- key in database but model not found at path
|
||||
- `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
|
||||
** TO DO: ** Resolve discrepancy between ModelInfo.location and
|
||||
ModelConfig.path.
|
||||
|
||||
### Emitting model loading events
|
||||
|
||||
When the `context` argument is passed to `get_model()`, it will
|
||||
When the `context` argument is passed to `load_model_*()`, it will
|
||||
retrieve the invocation event bus from the passed `InvocationContext`
|
||||
object to emit events on the invocation bus. The two events are
|
||||
"model_load_started" and "model_load_completed". Both carry the
|
||||
@@ -1556,10 +1614,174 @@ payload=dict(
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
submodel_type=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
)
|
||||
```
|
||||
|
||||
### Adding Model Loaders
|
||||
|
||||
Model loaders are small classes that inherit from the `ModelLoader`
|
||||
base class. They typically implement one method `_load_model()` whose
|
||||
signature is:
|
||||
|
||||
```
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
```
|
||||
|
||||
`_load_model()` will be passed the path to the model on disk, an
|
||||
optional repository variant (used by the diffusers loaders to select,
|
||||
e.g. the `fp16` variant, and an optional submodel_type for main and
|
||||
onnx models.
|
||||
|
||||
To install a new loader, place it in
|
||||
`invokeai/backend/model_manager/load/model_loaders`. Inherit from
|
||||
`ModelLoader` and use the `@ModelLoaderRegistry.register()` decorator to
|
||||
indicate what type of models the loader can handle.
|
||||
|
||||
Here is a complete example from `generic_diffusers.py`, which is able
|
||||
to load several different diffusers types:
|
||||
|
||||
```
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
||||
class GenericDiffusersLoader(ModelLoader):
|
||||
"""Class to load simple diffusers models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
model_class = self._get_hf_load_class(model_path)
|
||||
if submodel_type is not None:
|
||||
raise Exception(f"There are no submodels in models of type {model_class}")
|
||||
variant = model_variant.value if model_variant else None
|
||||
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
|
||||
return result
|
||||
```
|
||||
|
||||
Note that a loader can register itself to handle several different
|
||||
model types. An exception will be raised if more than one loader tries
|
||||
to register the same model type.
|
||||
|
||||
#### Conversion
|
||||
|
||||
Some models require conversion to diffusers format before they can be
|
||||
loaded. These loaders should override two additional methods:
|
||||
|
||||
```
|
||||
_needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool
|
||||
_convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
```
|
||||
|
||||
The first method accepts the model configuration, the path to where
|
||||
the unmodified model is currently installed, and a proposed
|
||||
destination for the converted model. This method returns True if the
|
||||
model needs to be converted. It typically does this by comparing the
|
||||
last modification time of the original model file to the modification
|
||||
time of the converted model. In some cases you will also want to check
|
||||
the modification date of the configuration record, in the event that
|
||||
the user has changed something like the scheduler prediction type that
|
||||
will require the model to be re-converted. See `controlnet.py` for an
|
||||
example of this logic.
|
||||
|
||||
The second method accepts the model configuration, the path to the
|
||||
original model on disk, and the desired output path for the converted
|
||||
model. It does whatever it needs to do to get the model into diffusers
|
||||
format, and returns the Path of the resulting model. (The path should
|
||||
ordinarily be the same as `output_path`.)
|
||||
|
||||
## The ModelManagerService object
|
||||
|
||||
For convenience, the API provides a `ModelManagerService` object which
|
||||
gives a single point of access to the major model manager
|
||||
services. This object is created at initialization time and can be
|
||||
found in the global `ApiDependencies.invoker.services.model_manager`
|
||||
object, or in `context.services.model_manager` from within an
|
||||
invocation.
|
||||
|
||||
In the examples below, we have retrieved the manager using:
|
||||
```
|
||||
mm = ApiDependencies.invoker.services.model_manager
|
||||
```
|
||||
|
||||
The following properties and methods will be available:
|
||||
|
||||
### mm.store
|
||||
|
||||
This retrieves the `ModelRecordService` associated with the
|
||||
manager. Example:
|
||||
|
||||
```
|
||||
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
|
||||
```
|
||||
|
||||
### mm.install
|
||||
|
||||
This retrieves the `ModelInstallService` associated with the manager.
|
||||
Example:
|
||||
|
||||
```
|
||||
job = mm.install.heuristic_import(`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
|
||||
```
|
||||
|
||||
### mm.load
|
||||
|
||||
This retrieves the `ModelLoaderService` associated with the manager. Example:
|
||||
|
||||
```
|
||||
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
|
||||
assert len(configs) > 0
|
||||
|
||||
loaded_model = mm.load.load_model(configs[0])
|
||||
```
|
||||
|
||||
The model manager also offers a few convenience shortcuts for loading
|
||||
models:
|
||||
|
||||
### mm.load_model_by_config(model_config, [submodel], [context]) -> LoadedModel
|
||||
|
||||
Same as `mm.load.load_model()`.
|
||||
|
||||
### mm.load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel
|
||||
|
||||
This accepts the combination of the model's name, type and base, which
|
||||
it passes to the model record config store for retrieval. If a unique
|
||||
model config is found, this method returns a `LoadedModel`. It can
|
||||
raise the following exceptions:
|
||||
|
||||
```
|
||||
UnknownModelException -- model with these attributes not known
|
||||
NotImplementedException -- the loader doesn't know how to load this type of model
|
||||
ValueError -- more than one model matches this combination of base/type/name
|
||||
```
|
||||
|
||||
### mm.load_model_by_key(key, [submodel], [context]) -> LoadedModel
|
||||
|
||||
This method takes a model key, looks it up using the
|
||||
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
|
||||
model configuration to `load_model_by_config()`. It may raise a
|
||||
`NotImplementedException`.
|
||||
|
||||
@@ -4,7 +4,6 @@ from logging import Logger
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@@ -25,8 +24,8 @@ from ..services.invocation_stats.invocation_stats_default import InvocationStats
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||
from ..services.model_install import ModelInstallService
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_metadata import ModelMetadataStoreSQL
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
@@ -85,16 +84,13 @@ class ApiDependencies:
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
metadata_store = ModelMetadataStore(db=db)
|
||||
model_install_service = ModelInstallService(
|
||||
app_config=config,
|
||||
record_store=model_record_service,
|
||||
model_metadata_service = ModelMetadataStoreSQL(db=db)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
|
||||
download_queue=download_queue_service,
|
||||
metadata_store=metadata_store,
|
||||
event_bus=events,
|
||||
events=events,
|
||||
)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
@@ -120,9 +116,7 @@ class ApiDependencies:
|
||||
latents=latents,
|
||||
logger=logger,
|
||||
model_manager=model_manager,
|
||||
model_records=model_record_service,
|
||||
download_queue=download_queue_service,
|
||||
model_install=model_install_service,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
processor=processor,
|
||||
|
||||
@@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]:
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_downloads():
|
||||
async def prune_downloads() -> Response:
|
||||
"""Prune completed and errored jobs."""
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
queue.prune_jobs()
|
||||
@@ -55,7 +55,7 @@ async def download(
|
||||
) -> DownloadJob:
|
||||
"""Download the source URL to the file or directory indicted in dest."""
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
return queue.download(source, dest, priority, access_token)
|
||||
return queue.download(source, Path(dest), priority, access_token)
|
||||
|
||||
|
||||
@download_queue_router.get(
|
||||
@@ -87,7 +87,7 @@ async def get_download_job(
|
||||
)
|
||||
async def cancel_download_job(
|
||||
id: int = Path(description="ID of the download job to cancel."),
|
||||
):
|
||||
) -> Response:
|
||||
"""Cancel a download job using its ID."""
|
||||
try:
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
@@ -105,7 +105,7 @@ async def cancel_download_job(
|
||||
204: {"description": "Download jobs have been cancelled"},
|
||||
},
|
||||
)
|
||||
async def cancel_all_download_jobs():
|
||||
async def cancel_all_download_jobs() -> Response:
|
||||
"""Cancel all download jobs."""
|
||||
ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
759
invokeai/app/api/routers/model_manager.py
Normal file
759
invokeai/app/api/routers/model_manager.py
Normal file
@@ -0,0 +1,759 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
import pathlib
|
||||
import shutil
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
MainCheckpointConfig,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
|
||||
models: List[AnyModelConfig]
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class ModelTagSet(BaseModel):
|
||||
"""Return tags for a set of models."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
author: str
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
##############################################################################
|
||||
# These are example inputs and outputs that are used in places where Swagger
|
||||
# is unable to generate a correct example.
|
||||
##############################################################################
|
||||
example_model_config = {
|
||||
"path": "string",
|
||||
"name": "string",
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config": "string",
|
||||
"key": "string",
|
||||
"original_hash": "string",
|
||||
"current_hash": "string",
|
||||
"description": "string",
|
||||
"source": "string",
|
||||
"last_modified": 0,
|
||||
"vae": "string",
|
||||
"variant": "normal",
|
||||
"prediction_type": "epsilon",
|
||||
"repo_variant": "fp16",
|
||||
"upcast_attention": False,
|
||||
"ztsnr_training": False,
|
||||
}
|
||||
|
||||
example_model_input = {
|
||||
"path": "/path/to/model",
|
||||
"name": "model_name",
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"description": "Model description",
|
||||
"vae": None,
|
||||
"variant": "normal",
|
||||
}
|
||||
|
||||
example_model_metadata = {
|
||||
"name": "ip_adapter_sd_image_encoder",
|
||||
"author": "InvokeAI",
|
||||
"tags": [
|
||||
"transformers",
|
||||
"safetensors",
|
||||
"clip_vision_model",
|
||||
"endpoints_compatible",
|
||||
"region:us",
|
||||
"has_space",
|
||||
"license:apache-2.0",
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
|
||||
"path": "ip_adapter_sd_image_encoder/README.md",
|
||||
"size": 628,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
|
||||
"path": "ip_adapter_sd_image_encoder/config.json",
|
||||
"size": 560,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
|
||||
"path": "ip_adapter_sd_image_encoder/model.safetensors",
|
||||
"size": 2528373448,
|
||||
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
|
||||
},
|
||||
],
|
||||
"type": "huggingface",
|
||||
"id": "InvokeAI/ip_adapter_sd_image_encoder",
|
||||
"tag_dict": {"license": "apache-2.0"},
|
||||
"last_modified": "2023-09-23T17:33:25Z",
|
||||
}
|
||||
|
||||
##############################################################################
|
||||
# ROUTES
|
||||
##############################################################################
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/",
|
||||
operation_id="list_model_records",
|
||||
)
|
||||
async def list_model_records(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
||||
model_format: Optional[ModelFormat] = Query(
|
||||
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
found_models: list[AnyModelConfig] = []
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(
|
||||
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
|
||||
)
|
||||
)
|
||||
else:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}",
|
||||
operation_id="get_model_record",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model configuration was retrieved successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
},
|
||||
)
|
||||
async def get_model_record(
|
||||
key: str = Path(description="Key of the model record to fetch."),
|
||||
) -> AnyModelConfig:
|
||||
"""Get a model record"""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
config: AnyModelConfig = record_store.get_model(key)
|
||||
return config
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
return results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/meta/i/{key}",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model metadata was retrieved successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "No metadata available"},
|
||||
},
|
||||
)
|
||||
async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Set[str] = record_store.list_tags()
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model was updated successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
info: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return model_response
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="del_model_record",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def del_model_record(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
Delete model record from database.
|
||||
|
||||
The configuration record will be removed. The corresponding weights files will be
|
||||
deleted as well if they reside within the InvokeAI "models" directory.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
installer.delete(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The model added successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
try:
|
||||
record_store.add_model(config.key, config)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# now fetch it out
|
||||
result: AnyModelConfig = record_store.get_model(config.key)
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/heuristic_import",
|
||||
operation_id="heuristic_import_model",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def heuristic_import(
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
example={"name": "modelT", "description": "antique cars"},
|
||||
),
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using a string identifier.
|
||||
|
||||
`source` can be any of the following.
|
||||
|
||||
1. A path on the local filesystem ('C:\\users\\fred\\model.safetensors')
|
||||
2. A Url pointing to a single downloadable model file
|
||||
3. A HuggingFace repo_id with any of the following formats:
|
||||
- model/name
|
||||
- model/name:fp16:vae
|
||||
- model/name::vae -- use default precision
|
||||
- model/name:fp16:path/to/model.safetensors
|
||||
- model/name::path/to/model.safetensors
|
||||
|
||||
`config` is an optional dict containing model configuration values that will override
|
||||
the ones that are probed automatically.
|
||||
|
||||
`access_token` is an optional access token for use with Urls that require
|
||||
authentication.
|
||||
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has `status` attribute
|
||||
that can be used to monitor progress.
|
||||
|
||||
See the documentation for `import_model_record` for more information on
|
||||
interpreting the job information returned by this route.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
result: ModelInstallJob = installer.heuristic_import(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/install",
|
||||
operation_id="import_model",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def import_model(
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using its local path, repo_id, or remote URL.
|
||||
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has `status` attribute
|
||||
that can be used to monitor progress.
|
||||
|
||||
The source object is a discriminated Union of LocalModelSource,
|
||||
HFModelSource and URLModelSource. Set the "type" field to the
|
||||
appropriate value:
|
||||
|
||||
* To install a local path using LocalModelSource, pass a source of form:
|
||||
```
|
||||
{
|
||||
"type": "local",
|
||||
"path": "/path/to/model",
|
||||
"inplace": false
|
||||
}
|
||||
```
|
||||
The "inplace" flag, if true, will register the model in place in its
|
||||
current filesystem location. Otherwise, the model will be copied
|
||||
into the InvokeAI models directory.
|
||||
|
||||
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
|
||||
```
|
||||
{
|
||||
"type": "hf",
|
||||
"repo_id": "stabilityai/stable-diffusion-2.0",
|
||||
"variant": "fp16",
|
||||
"subfolder": "vae",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}
|
||||
```
|
||||
The `variant`, `subfolder` and `access_token` fields are optional.
|
||||
|
||||
* To install a remote model using an arbitrary URL, pass:
|
||||
```
|
||||
{
|
||||
"type": "url",
|
||||
"url": "http://www.civitai.com/models/123456",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}
|
||||
```
|
||||
The `access_token` field is optonal
|
||||
|
||||
The model's configuration record will be probed and filled in
|
||||
automatically. To override the default guesses, pass "metadata"
|
||||
with a Dict containing the attributes you wish to override.
|
||||
|
||||
Installation occurs in the background. Either use list_model_install_jobs()
|
||||
to poll for completion, or listen on the event bus for the following events:
|
||||
|
||||
* "model_install_running"
|
||||
* "model_install_completed"
|
||||
* "model_install_error"
|
||||
|
||||
On successful completion, the event's payload will contain the field "key"
|
||||
containing the installed ID of the model. On an error, the event's payload
|
||||
will contain the fields "error_type" and "error" describing the nature of the
|
||||
error and its traceback, respectively.
|
||||
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
result: ModelInstallJob = installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
"""Return the list of model install jobs.
|
||||
|
||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||
the nature of the job and its progress. The `status` is one of:
|
||||
|
||||
* "waiting" -- Job is waiting in the queue to run
|
||||
* "downloading" -- Model file(s) are downloading
|
||||
* "running" -- Model has downloaded and the model probing and registration process is running
|
||||
* "completed" -- Installation completed successfully
|
||||
* "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
|
||||
* "cancelled" -- Job was cancelled before completion.
|
||||
|
||||
Once completed, information about the model such as its size, base
|
||||
model, type, and metadata can be retrieved from the `config_out`
|
||||
field. For multi-file models such as diffusers, information on individual files
|
||||
can be retrieved from `download_parts`.
|
||||
|
||||
See the example and schema below for more information.
|
||||
"""
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
|
||||
return jobs
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/import/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
404: {"description": "No such job"},
|
||||
},
|
||||
)
|
||||
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||
"""
|
||||
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
|
||||
for information on the format of the return value.
|
||||
"""
|
||||
try:
|
||||
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/import/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was cancelled successfully"},
|
||||
415: {"description": "No such job"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
|
||||
"""Cancel the model install job(s) corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/import",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed and errored jobs have been pruned"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_model_install_jobs() -> Response:
|
||||
"""Prune all completed and errored jobs from the install job list."""
|
||||
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/sync",
|
||||
operation_id="sync_models_to_config",
|
||||
responses={
|
||||
204: {"description": "Model config record database resynced with files on disk"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def sync_models_to_config() -> Response:
|
||||
"""
|
||||
Traverse the models and autoimport directories.
|
||||
|
||||
Model files without a corresponding
|
||||
record in the database are added. Orphan records without a models file are deleted.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_manager.install.sync_to_config()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_manager_router.put(
|
||||
"/convert/{key}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Model converted successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
409: {"description": "There is already a model registered at this location"},
|
||||
},
|
||||
)
|
||||
async def convert_model(
|
||||
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Permanently convert a model into diffusers format, replacing the safetensors version.
|
||||
Note that during the conversion process the key and model hash will change.
|
||||
The return value is the model configuration for the converted model.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
loader = ApiDependencies.invoker.services.model_manager.load
|
||||
store = ApiDependencies.invoker.services.model_manager.store
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
|
||||
try:
|
||||
model_config = store.get_model(key)
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
|
||||
if not isinstance(model_config, MainCheckpointConfig):
|
||||
logger.error(f"The model with key {key} is not a main checkpoint model.")
|
||||
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||
|
||||
# loading the model will convert it into a cached diffusers file
|
||||
loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
||||
|
||||
# Get the path of the converted model from the loader
|
||||
cache_path = loader.convert_cache.cache_path(key)
|
||||
assert cache_path.exists()
|
||||
|
||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||
original_name = model_config.name
|
||||
model_config.name = f"{original_name}.DELETE"
|
||||
store.update_model(key, config=model_config)
|
||||
|
||||
# install the diffusers
|
||||
try:
|
||||
new_key = installer.install_path(
|
||||
cache_path,
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"original_hash": model_config.original_hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
# get the original metadata
|
||||
if orig_metadata := store.get_metadata(key):
|
||||
store.metadata_store.add_metadata(new_key, orig_metadata)
|
||||
|
||||
# delete the original safetensors file
|
||||
installer.delete(key)
|
||||
|
||||
# delete the cached version
|
||||
shutil.rmtree(cache_path)
|
||||
|
||||
# return the config record for the new diffusers directory
|
||||
new_config: AnyModelConfig = store.get_model(new_key)
|
||||
return new_config
|
||||
|
||||
|
||||
@model_manager_router.put(
|
||||
"/merge",
|
||||
operation_id="merge",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Model converted successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
409: {"description": "There is already a model registered at this location"},
|
||||
},
|
||||
)
|
||||
async def merge(
|
||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
force: bool = Body(
|
||||
description="Force merging of models created with different versions of diffusers",
|
||||
default=False,
|
||||
),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
merge_dest_directory: Optional[str] = Body(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
||||
```
|
||||
Argument Description [default]
|
||||
-------- ----------------------
|
||||
keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
merged_model_name Name for the merged model [Concat model names]
|
||||
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
```
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
merger = ModelMerger(installer)
|
||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
response = merger.merge_diffusion_models_and_save(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
except UnknownModelException:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{keys}' not found",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
@@ -1,472 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
import pathlib
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
|
||||
models: List[AnyModelConfig]
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class ModelTagSet(BaseModel):
|
||||
"""Return tags for a set of models."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
author: str
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/",
|
||||
operation_id="list_model_records",
|
||||
)
|
||||
async def list_model_records(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
||||
model_format: Optional[ModelFormat] = Query(
|
||||
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
found_models: list[AnyModelConfig] = []
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(
|
||||
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
|
||||
)
|
||||
)
|
||||
else:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/i/{key}",
|
||||
operation_id="get_model_record",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
},
|
||||
)
|
||||
async def get_model_record(
|
||||
key: str = Path(description="Key of the model record to fetch."),
|
||||
) -> AnyModelConfig:
|
||||
"""Get a model record"""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
try:
|
||||
return record_store.get_model(key)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.get("/meta", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/meta/i/{key}",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "No metadata available"},
|
||||
},
|
||||
)
|
||||
async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
result = record_store.get_metadata(key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
return record_store.list_tags()
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
responses={
|
||||
200: {"description": "The model was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=AnyModelConfig,
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
try:
|
||||
model_response = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return model_response
|
||||
|
||||
|
||||
@model_records_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="del_model_record",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def del_model_record(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
Delete model record from database.
|
||||
|
||||
The configuration record will be removed. The corresponding weights files will be
|
||||
deleted as well if they reside within the InvokeAI "models" directory.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
installer.delete(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
try:
|
||||
record_store.add_model(config.key, config)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# now fetch it out
|
||||
return record_store.get_model(config.key)
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
"/import",
|
||||
operation_id="import_model_record",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def import_model(
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Add a model using its local path, repo_id, or remote URL.
|
||||
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has `status` attribute
|
||||
that can be used to monitor progress.
|
||||
|
||||
The source object is a discriminated Union of LocalModelSource,
|
||||
HFModelSource and URLModelSource. Set the "type" field to the
|
||||
appropriate value:
|
||||
|
||||
* To install a local path using LocalModelSource, pass a source of form:
|
||||
`{
|
||||
"type": "local",
|
||||
"path": "/path/to/model",
|
||||
"inplace": false
|
||||
}`
|
||||
The "inplace" flag, if true, will register the model in place in its
|
||||
current filesystem location. Otherwise, the model will be copied
|
||||
into the InvokeAI models directory.
|
||||
|
||||
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
|
||||
`{
|
||||
"type": "hf",
|
||||
"repo_id": "stabilityai/stable-diffusion-2.0",
|
||||
"variant": "fp16",
|
||||
"subfolder": "vae",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}`
|
||||
The `variant`, `subfolder` and `access_token` fields are optional.
|
||||
|
||||
* To install a remote model using an arbitrary URL, pass:
|
||||
`{
|
||||
"type": "url",
|
||||
"url": "http://www.civitai.com/models/123456",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}`
|
||||
The `access_token` field is optonal
|
||||
|
||||
The model's configuration record will be probed and filled in
|
||||
automatically. To override the default guesses, pass "metadata"
|
||||
with a Dict containing the attributes you wish to override.
|
||||
|
||||
Installation occurs in the background. Either use list_model_install_jobs()
|
||||
to poll for completion, or listen on the event bus for the following events:
|
||||
|
||||
"model_install_running"
|
||||
"model_install_completed"
|
||||
"model_install_error"
|
||||
|
||||
On successful completion, the event's payload will contain the field "key"
|
||||
containing the installed ID of the model. On an error, the event's payload
|
||||
will contain the fields "error_type" and "error" describing the nature of the
|
||||
error and its traceback, respectively.
|
||||
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
result: ModelInstallJob = installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
"""Return list of model install jobs."""
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
|
||||
return jobs
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/import/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
404: {"description": "No such job"},
|
||||
},
|
||||
)
|
||||
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||
"""Return model install job corresponding to the given source."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.delete(
|
||||
"/import/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was cancelled successfully"},
|
||||
415: {"description": "No such job"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
|
||||
"""Cancel the model install job(s) corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/import",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed and errored jobs have been pruned"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_model_install_jobs() -> Response:
|
||||
"""Prune all completed and errored jobs from the install job list."""
|
||||
ApiDependencies.invoker.services.model_install.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/sync",
|
||||
operation_id="sync_models_to_config",
|
||||
responses={
|
||||
204: {"description": "Model config record database resynced with files on disk"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def sync_models_to_config() -> Response:
|
||||
"""
|
||||
Traverse the models and autoimport directories.
|
||||
|
||||
Model files without a corresponding
|
||||
record in the database are added. Orphan records without a models file are deleted.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_install.sync_to_config()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_records_router.put(
|
||||
"/merge",
|
||||
operation_id="merge",
|
||||
)
|
||||
async def merge(
|
||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
force: bool = Body(
|
||||
description="Force merging of models created with different versions of diffusers",
|
||||
default=False,
|
||||
),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
merge_dest_directory: Optional[str] = Body(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Merge diffusers models.
|
||||
|
||||
keys: List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
merged_model_name: Name for the merged model [Concat model names]
|
||||
alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
force: If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
merge_dest_directory: Specify a directory to store the merged model in [models directory]
|
||||
"""
|
||||
print(f"here i am, keys={keys}")
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
merger = ModelMerger(installer)
|
||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
response = merger.merge_diffusion_models_and_save(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
except UnknownModelException:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{keys}' not found",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
@@ -1,427 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
|
||||
|
||||
import pathlib
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||
from invokeai.backend.model_management.models import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
InvalidModelException,
|
||||
ModelNotFoundException,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse)
|
||||
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelResponseValidator = TypeAdapter(ImportModelResponse)
|
||||
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse)
|
||||
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
ModelsListValidator = TypeAdapter(ModelsList)
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
responses={200: {"model": ModelsList}},
|
||||
)
|
||||
async def list_models(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
if base_models and len(base_models) > 0:
|
||||
models_raw = []
|
||||
for base_model in base_models:
|
||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||
else:
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||
models = ModelsListValidator.validate_python({"models": models_raw})
|
||||
return models
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="update_model",
|
||||
responses={
|
||||
200: {"description": "The model was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=UpdateModelResponse,
|
||||
)
|
||||
async def update_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> UpdateModelResponse:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
# rename operation requested
|
||||
if info.model_name != model_name or info.base_model != base_model:
|
||||
ApiDependencies.invoker.services.model_manager.rename_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=info.model_name,
|
||||
new_base=info.base_model,
|
||||
)
|
||||
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
|
||||
# update information to support an update of attributes
|
||||
model_name = info.model_name
|
||||
base_model = info.base_model
|
||||
new_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
if new_info.get("path") != previous_info.get(
|
||||
"path"
|
||||
): # model manager moved model path during rename - don't overwrite it
|
||||
info.path = new_info.get("path")
|
||||
|
||||
# replace empty string values with None/null to avoid phenomenon of vae: ''
|
||||
info_dict = info.model_dump()
|
||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.update_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_attributes=info_dict,
|
||||
)
|
||||
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
model_response = UpdateModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse,
|
||||
)
|
||||
async def import_model(
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
|
||||
description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
|
||||
default=None,
|
||||
),
|
||||
) -> ImportModelResponse:
|
||||
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
||||
|
||||
location = location.strip("\"' ")
|
||||
items_to_import = {location}
|
||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import=items_to_import,
|
||||
prediction_type_helper=lambda x: prediction_types.get(prediction_type),
|
||||
)
|
||||
info = installed_models.get(location)
|
||||
|
||||
if not info:
|
||||
logger.error("Import failed")
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
logger.info(f"Successfully imported {location}, got {info}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||
)
|
||||
return ImportModelResponseValidator.validate_python(model_raw)
|
||||
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/add",
|
||||
operation_id="add_model",
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse,
|
||||
)
|
||||
async def add_model(
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> ImportModelResponse:
|
||||
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
info.model_name,
|
||||
info.base_model,
|
||||
info.model_type,
|
||||
model_attributes=info.model_dump(),
|
||||
)
|
||||
logger.info(f"Successfully added {info.model_name}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.model_name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type,
|
||||
)
|
||||
return ImportModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
)
|
||||
async def delete_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.del_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
logger.info(f"Deleted model: {model_name}")
|
||||
return Response(status_code=204)
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/convert/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: {"description": "Model converted successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ConvertModelResponse,
|
||||
)
|
||||
async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
convert_dest_directory: Optional[str] = Query(
|
||||
default=None, description="Save the converted model to the designated directory"
|
||||
),
|
||||
) -> ConvertModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(
|
||||
model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
convert_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
response = ConvertModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/search",
|
||||
operation_id="search_for_models",
|
||||
responses={
|
||||
200: {"description": "Directory searched successfully"},
|
||||
404: {"description": "Invalid directory path"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[pathlib.Path],
|
||||
)
|
||||
async def search_for_models(
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models"),
|
||||
) -> List[pathlib.Path]:
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"The search path '{search_path}' does not exist or is not directory",
|
||||
)
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/ckpt_confs",
|
||||
operation_id="list_ckpt_configs",
|
||||
responses={
|
||||
200: {"description": "paths retrieved successfully"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[pathlib.Path],
|
||||
)
|
||||
async def list_ckpt_configs() -> List[pathlib.Path]:
|
||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/sync",
|
||||
operation_id="sync_to_config",
|
||||
responses={
|
||||
201: {"description": "synchronization successful"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=bool,
|
||||
)
|
||||
async def sync_to_config() -> bool:
|
||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||
in-memory data structures with disk data structures."""
|
||||
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||
return True
|
||||
|
||||
|
||||
# There's some weird pydantic-fastapi behaviour that requires this to be a separate class
|
||||
# TODO: After a few updates, see if it works inside the route operation handler?
|
||||
class MergeModelsBody(BaseModel):
|
||||
model_names: List[str] = Field(description="model name", min_length=2, max_length=3)
|
||||
merged_model_name: Optional[str] = Field(description="Name of destination model")
|
||||
alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5)
|
||||
interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method")
|
||||
force: Optional[bool] = Field(
|
||||
description="Force merging of models created with different versions of diffusers",
|
||||
default=False,
|
||||
)
|
||||
|
||||
merge_dest_directory: Optional[str] = Field(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
operation_id="merge_models",
|
||||
responses={
|
||||
200: {"description": "Model converted successfully"},
|
||||
400: {"description": "Incompatible models"},
|
||||
404: {"description": "One or more models not found"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=MergeModelResponse,
|
||||
)
|
||||
async def merge_models(
|
||||
body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)],
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(
|
||||
f"Merging models: {body.model_names} into {body.merge_dest_directory or '<MODELS>'}/{body.merged_model_name}"
|
||||
)
|
||||
dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||
model_names=body.model_names,
|
||||
base_model=base_model,
|
||||
merged_model_name=body.merged_model_name or "+".join(body.model_names),
|
||||
alpha=body.alpha,
|
||||
interp=body.interp,
|
||||
force=body.force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
result.name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
response = ConvertModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{body.model_names}' not found",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
@@ -47,8 +47,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
boards,
|
||||
download_queue,
|
||||
images,
|
||||
model_records,
|
||||
models,
|
||||
model_manager,
|
||||
session_queue,
|
||||
sessions,
|
||||
utilities,
|
||||
@@ -115,8 +114,7 @@ async def shutdown_event() -> None:
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
app.include_router(model_records.model_records_router, prefix="/api")
|
||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||
app.include_router(download_queue.download_queue_router, prefix="/api")
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
@@ -178,21 +176,23 @@ def custom_openapi() -> dict[str, Any]:
|
||||
invoker_schema["class"] = "invocation"
|
||||
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
# This code no longer seems to be necessary?
|
||||
# Leave it here just in case
|
||||
#
|
||||
# from invokeai.backend.model_manager import get_model_config_formats
|
||||
# formats = get_model_config_formats()
|
||||
# for model_config_name, enum_set in formats.items():
|
||||
|
||||
for model_config_format_enum in set(get_model_config_enums()):
|
||||
name = model_config_format_enum.__qualname__
|
||||
# if model_config_name in openapi_schema["components"]["schemas"]:
|
||||
# # print(f"Config with name {name} already defined")
|
||||
# continue
|
||||
|
||||
if name in openapi_schema["components"]["schemas"]:
|
||||
# print(f"Config with name {name} already defined")
|
||||
continue
|
||||
|
||||
openapi_schema["components"]["schemas"][name] = {
|
||||
"title": name,
|
||||
"description": "An enumeration.",
|
||||
"type": "string",
|
||||
"enum": [v.value for v in model_config_format_enum],
|
||||
}
|
||||
# openapi_schema["components"]["schemas"][model_config_name] = {
|
||||
# "title": model_config_name,
|
||||
# "description": "An enumeration.",
|
||||
# "type": "string",
|
||||
# "enum": [v.value for v in enum_set],
|
||||
# }
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
@@ -1,22 +1,27 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import ModelNotFoundException, ModelType
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ..util.ti_utils import extract_ti_triggers_from_prompt
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -66,21 +71,22 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
tokenizer_info = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
lora_info = context.services.model_manager.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
@@ -90,25 +96,20 @@ class CompelInvocation(BaseInvocation):
|
||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
loaded_model = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
).model
|
||||
assert isinstance(loaded_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, loaded_model))
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with (
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
@@ -116,7 +117,7 @@ class CompelInvocation(BaseInvocation):
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
|
||||
):
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
@@ -150,7 +151,7 @@ class CompelInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
context.services.latents.save(conditioning_name, conditioning_data) # TODO: fix type mismatch here
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
@@ -160,6 +161,8 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
"""Prompt processor for SDXL models."""
|
||||
|
||||
def run_clip_compel(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@@ -168,26 +171,27 @@ class SDXLPromptInvocationBase:
|
||||
get_pooled: bool,
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||
tokenizer_info = context.services.model_manager.load_model_by_key(
|
||||
**clip_field.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_manager.load_model_by_key(
|
||||
**clip_field.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
cpu_text_encoder = text_encoder_info.context.model
|
||||
cpu_text_encoder = text_encoder_info.model
|
||||
assert isinstance(cpu_text_encoder, torch.nn.Module)
|
||||
c = torch.zeros(
|
||||
(
|
||||
1,
|
||||
cpu_text_encoder.config.max_position_embeddings,
|
||||
cpu_text_encoder.config.hidden_size,
|
||||
),
|
||||
dtype=text_encoder_info.context.cache.precision,
|
||||
dtype=cpu_text_encoder.dtype,
|
||||
)
|
||||
if get_pooled:
|
||||
c_pooled = torch.zeros(
|
||||
@@ -198,12 +202,14 @@ class SDXLPromptInvocationBase:
|
||||
c_pooled = None
|
||||
return c, c_pooled, None
|
||||
|
||||
def _lora_loader():
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
lora_info = context.services.model_manager.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
lora_model = lora_info.model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
yield (lora_model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
@@ -213,25 +219,24 @@ class SDXLPromptInvocationBase:
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
ti_model = context.services.model_manager.load_model_by_attr(
|
||||
model_name=name,
|
||||
base_model=text_encoder_info.config.base,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).model
|
||||
assert isinstance(ti_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, ti_model))
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
logger.warning(f'trigger: "{trigger}" not found')
|
||||
except ValueError:
|
||||
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||
|
||||
with (
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
@@ -239,7 +244,7 @@ class SDXLPromptInvocationBase:
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
|
||||
):
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
@@ -357,6 +362,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
dim=1,
|
||||
)
|
||||
|
||||
assert c2_pooled is not None
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
@@ -410,6 +416,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
|
||||
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
||||
|
||||
assert c2_pooled is not None
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
@@ -459,9 +466,9 @@ class ClipSkipInvocation(BaseInvocation):
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer,
|
||||
tokenizer: CLIPTokenizer,
|
||||
prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||
truncate_if_too_long=False,
|
||||
truncate_if_too_long: bool = False,
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
@@ -473,7 +480,9 @@ def get_max_token_count(
|
||||
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
|
||||
|
||||
|
||||
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
|
||||
def get_tokens_for_prompt_object(
|
||||
tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
|
||||
) -> List[str]:
|
||||
if type(parsed_prompt) is Blend:
|
||||
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
|
||||
|
||||
@@ -486,24 +495,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
|
||||
for x in parsed_prompt.children
|
||||
]
|
||||
text = " ".join(text_fragments)
|
||||
tokens = tokenizer.tokenize(text)
|
||||
tokens: List[str] = tokenizer.tokenize(text)
|
||||
if truncate_if_too_long:
|
||||
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
|
||||
tokens = tokens[0:max_tokens_length]
|
||||
return tokens
|
||||
|
||||
|
||||
def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
|
||||
def log_tokenization_for_conjunction(
|
||||
c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
|
||||
) -> None:
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
for i, p in enumerate(c.prompts):
|
||||
if len(c.prompts) > 1:
|
||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||
else:
|
||||
assert display_label_prefix is not None
|
||||
this_display_label_prefix = display_label_prefix
|
||||
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
|
||||
def log_tokenization_for_prompt_object(
|
||||
p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
|
||||
) -> None:
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
if type(p) is Blend:
|
||||
blend: Blend = p
|
||||
@@ -543,7 +557,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
|
||||
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
|
||||
|
||||
|
||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
def log_tokenization_for_text(
|
||||
text: str,
|
||||
tokenizer: CLIPTokenizer,
|
||||
display_label: Optional[str] = None,
|
||||
truncate_if_too_long: Optional[bool] = False,
|
||||
) -> None:
|
||||
"""shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
|
||||
@@ -24,7 +24,7 @@ from controlnet_aux import (
|
||||
)
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
@@ -32,7 +32,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
|
||||
|
||||
from ...backend.model_management import BaseModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -57,10 +56,7 @@ CONTROLNET_RESIZE_VALUES = Literal[
|
||||
class ControlNetModelField(BaseModel):
|
||||
"""ControlNet model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the ControlNet model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model config record key for the ControlNet model")
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
from builtins import float
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -17,22 +17,16 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
|
||||
|
||||
# LS: Consider moving these two classes into model.py
|
||||
class IPAdapterModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the IP-Adapter model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Key to the IP-Adapter model")
|
||||
|
||||
|
||||
class CLIPVisionModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
|
||||
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Key to the CLIP Vision image encoder model")
|
||||
|
||||
|
||||
class IPAdapterField(BaseModel):
|
||||
@@ -49,12 +43,12 @@ class IPAdapterField(BaseModel):
|
||||
|
||||
@field_validator("weight")
|
||||
@classmethod
|
||||
def validate_ip_adapter_weight(cls, v):
|
||||
def validate_ip_adapter_weight(cls, v: float) -> float:
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
def validate_begin_end_step_percent(self) -> Self:
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
@@ -87,33 +81,25 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
@field_validator("weight")
|
||||
@classmethod
|
||||
def validate_ip_adapter_weight(cls, v):
|
||||
def validate_ip_adapter_weight(cls, v: float) -> float:
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
def validate_begin_end_step_percent(self) -> Self:
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.services.model_manager.model_info(
|
||||
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
|
||||
)
|
||||
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
|
||||
# directly, and 2) we are reading from disk every time this invocation is called without caching the result.
|
||||
# A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this
|
||||
# is currently messy due to differences between how the model info is generated when installing a model from
|
||||
# disk vs. downloading the model.
|
||||
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
|
||||
os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"])
|
||||
)
|
||||
ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key)
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_model = CLIPVisionModelField(
|
||||
model_name=image_encoder_model_name,
|
||||
base_model=BaseModelType.Any,
|
||||
image_encoder_models = context.services.model_manager.store.search_by_attr(
|
||||
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
|
||||
)
|
||||
assert len(image_encoder_models) == 1
|
||||
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
|
||||
return IPAdapterOutput(
|
||||
ip_adapter=IPAdapterField(
|
||||
image=self.image,
|
||||
|
||||
@@ -3,13 +3,15 @@
|
||||
import math
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.adapter import T2IAdapter
|
||||
from diffusers.models.attention_processor import (
|
||||
@@ -18,8 +20,10 @@ from diffusers.models.attention_processor import (
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from PIL import Image
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
@@ -39,13 +43,13 @@ from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_management.seamless import set_seamless
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
IPAdapterData,
|
||||
@@ -77,7 +81,9 @@ if choose_torch_device() == torch.device("mps"):
|
||||
|
||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(SCHEDULER_MAP.keys())
|
||||
] # FIXME: "Invalid type alias". This defeats static type checking.
|
||||
|
||||
# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
|
||||
# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
|
||||
@@ -131,10 +137,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
ui_order=4,
|
||||
)
|
||||
|
||||
def prep_mask_tensor(self, mask_image):
|
||||
def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
|
||||
if mask_image.mode != "L":
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
if mask_tensor.dim() == 3:
|
||||
mask_tensor = mask_tensor.unsqueeze(0)
|
||||
# if shape is not None:
|
||||
@@ -145,24 +151,24 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
if self.image is not None:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
else:
|
||||
image = None
|
||||
image_tensor = None
|
||||
|
||||
mask = self.prep_mask_tensor(
|
||||
context.services.images.get_pil_image(self.mask.image_name),
|
||||
)
|
||||
|
||||
if image is not None:
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
if image_tensor is not None:
|
||||
vae_info = context.services.model_manager.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
# TODO:
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||
|
||||
@@ -189,7 +195,7 @@ def get_scheduler(
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
orig_scheduler_info = context.services.model_manager.load_model_by_key(
|
||||
**scheduler_info.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
@@ -200,7 +206,7 @@ def get_scheduler(
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {
|
||||
**scheduler_config,
|
||||
**scheduler_extra_config,
|
||||
**scheduler_extra_config, # FIXME
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
|
||||
@@ -213,6 +219,7 @@ def get_scheduler(
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, "uses_inpainting_model"):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
assert isinstance(scheduler, Scheduler)
|
||||
return scheduler
|
||||
|
||||
|
||||
@@ -296,7 +303,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
@@ -326,9 +333,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
scheduler,
|
||||
unet,
|
||||
seed,
|
||||
scheduler: Scheduler,
|
||||
unet: UNet2DConditionModel,
|
||||
seed: int,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
@@ -351,7 +358,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
|
||||
scheduler,
|
||||
# for ddim scheduler
|
||||
eta=0.0, # ddim_eta
|
||||
@@ -363,8 +370,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def create_pipeline(
|
||||
self,
|
||||
unet,
|
||||
scheduler,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
# configure_model_padding(
|
||||
@@ -375,10 +382,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.block_out_channels = [0]
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.config = FakeVae.FakeVaeConfig()
|
||||
|
||||
return StableDiffusionGeneratorPipeline(
|
||||
@@ -395,11 +402,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def prep_control_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
control_input: Union[ControlField, List[ControlField]],
|
||||
control_input: Optional[Union[ControlField, List[ControlField]]],
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
) -> Optional[List[ControlNetData]]:
|
||||
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
|
||||
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
|
||||
@@ -422,10 +429,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
context.services.model_manager.load_model_by_key(
|
||||
key=control_info.control_model.key,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
@@ -490,27 +495,25 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=single_ip_adapter.ip_adapter_model.model_name,
|
||||
model_type=ModelType.IPAdapter,
|
||||
base_model=single_ip_adapter.ip_adapter_model.base_model,
|
||||
context.services.model_manager.load_model_by_key(
|
||||
key=single_ip_adapter.ip_adapter_model.key,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.services.model_manager.get_model(
|
||||
model_name=single_ip_adapter.image_encoder_model.model_name,
|
||||
model_type=ModelType.CLIPVision,
|
||||
base_model=single_ip_adapter.image_encoder_model.base_model,
|
||||
image_encoder_model_info = context.services.model_manager.load_model_by_key(
|
||||
key=single_ip_adapter.image_encoder_model.key,
|
||||
context=context,
|
||||
)
|
||||
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_images = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_images, list):
|
||||
single_ipa_images = [single_ipa_images]
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
single_ipa_image_fields = [single_ipa_image_fields]
|
||||
|
||||
single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images]
|
||||
single_ipa_images = [
|
||||
context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
|
||||
]
|
||||
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
@@ -554,23 +557,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_info = context.services.model_manager.get_model(
|
||||
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
|
||||
model_type=ModelType.T2IAdapter,
|
||||
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
|
||||
t2i_adapter_model_info = context.services.model_manager.load_model_by_key(
|
||||
key=t2i_adapter_field.t2i_adapter_model.key,
|
||||
context=context,
|
||||
)
|
||||
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
|
||||
if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1:
|
||||
max_unet_downscale = 8
|
||||
elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
|
||||
elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL:
|
||||
max_unet_downscale = 4
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
|
||||
)
|
||||
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.")
|
||||
|
||||
t2i_adapter_model: T2IAdapter
|
||||
with t2i_adapter_model_info as t2i_adapter_model:
|
||||
@@ -593,7 +592,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=False,
|
||||
width=t2i_input_width,
|
||||
height=t2i_input_height,
|
||||
num_channels=t2i_adapter_model.config.in_channels,
|
||||
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
|
||||
device=t2i_adapter_model.device,
|
||||
dtype=t2i_adapter_model.dtype,
|
||||
resize_mode=t2i_adapter_field.resize_mode,
|
||||
@@ -618,7 +617,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||
# TODO: research more for second order schedulers timesteps
|
||||
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
||||
def init_scheduler(
|
||||
self,
|
||||
scheduler: Union[Scheduler, ConfigMixin],
|
||||
device: torch.device,
|
||||
steps: int,
|
||||
denoising_start: float,
|
||||
denoising_end: float,
|
||||
) -> Tuple[int, List[int], int]:
|
||||
assert isinstance(scheduler, ConfigMixin)
|
||||
if scheduler.config.get("cpu_only", False):
|
||||
scheduler.set_timesteps(steps, device="cpu")
|
||||
timesteps = scheduler.timesteps.to(device=device)
|
||||
@@ -630,11 +637,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
_timesteps = timesteps[:: scheduler.order]
|
||||
|
||||
# get start timestep index
|
||||
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
|
||||
t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
|
||||
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
|
||||
|
||||
# get end timestep index
|
||||
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
|
||||
t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
|
||||
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
|
||||
|
||||
# apply order to indexes
|
||||
@@ -647,7 +654,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
|
||||
def prep_inpaint_mask(self, context, latents):
|
||||
def prep_inpaint_mask(
|
||||
self, context: InvocationContext, latents: torch.Tensor
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
if self.denoise_mask is None:
|
||||
return None, None
|
||||
|
||||
@@ -700,31 +709,36 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
||||
# get the unet's config so that we can pass the base to dispatch_progress()
|
||||
unet_config = context.services.model_manager.store.get_model(self.unet.unet.key)
|
||||
|
||||
def _lora_loader():
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
self.dispatch_progress(context, source_node_id, state, unet_config.base)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
lora_info = context.services.model_manager.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
unet_info = context.services.model_manager.load_model_by_key(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
|
||||
set_seamless(unet_info.context.model, self.unet.seamless_axes),
|
||||
ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
|
||||
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
|
||||
unet_info as unet,
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
@@ -822,12 +836,13 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_manager.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
vae.to(dtype=torch.float32)
|
||||
@@ -1016,8 +1031,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info, upcast, tiled, image_tensor):
|
||||
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
orig_dtype = vae.dtype
|
||||
if upcast:
|
||||
vae.to(dtype=torch.float32)
|
||||
@@ -1063,7 +1079,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_manager.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
@@ -1082,14 +1098,19 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
@singledispatchmethod
|
||||
@staticmethod
|
||||
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
latents: torch.Tensor = image_tensor_dist.sample().to(
|
||||
dtype=vae.dtype
|
||||
) # FIXME: uses torch.randn. make reproducible!
|
||||
return latents
|
||||
|
||||
@_encode_to_tensor.register
|
||||
@staticmethod
|
||||
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return vae.encode(image_tensor).latents
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
latents: torch.FloatTensor = vae.encode(image_tensor).latents
|
||||
return latents
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -1122,7 +1143,12 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
|
||||
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
def slerp(
|
||||
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
|
||||
v0: Union[torch.Tensor, npt.NDArray[Any]],
|
||||
v1: Union[torch.Tensor, npt.NDArray[Any]],
|
||||
DOT_THRESHOLD: float = 0.9995,
|
||||
) -> Union[torch.Tensor, npt.NDArray[Any]]:
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
@@ -1155,12 +1181,16 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(device)
|
||||
|
||||
return v2
|
||||
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
|
||||
return v2_torch
|
||||
else:
|
||||
assert isinstance(v2, np.ndarray)
|
||||
return v2
|
||||
|
||||
# blend
|
||||
blended_latents = slerp(self.alpha, latents_a, latents_b)
|
||||
bl = slerp(self.alpha, latents_a, latents_b)
|
||||
assert isinstance(bl, torch.Tensor)
|
||||
blended_latents: torch.Tensor = bl # for type checking convenience
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
blended_latents = blended_latents.to("cpu")
|
||||
@@ -1256,15 +1286,19 @@ class IdealSizeInvocation(BaseInvocation):
|
||||
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)",
|
||||
)
|
||||
|
||||
def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR):
|
||||
def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
|
||||
return tuple((x - x % multiple_of) for x in args)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
||||
unet_config = context.services.model_manager.load_model_by_key(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
aspect = self.width / self.height
|
||||
dimension = 512
|
||||
if self.unet.unet.base_model == BaseModelType.StableDiffusion2:
|
||||
dimension: float = 512
|
||||
if unet_config.base == BaseModelType.StableDiffusion2:
|
||||
dimension = 768
|
||||
elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL:
|
||||
elif unet_config.base == BaseModelType.StableDiffusionXL:
|
||||
dimension = 1024
|
||||
dimension = dimension * self.multiplier
|
||||
min_dimension = math.floor(dimension * 0.5)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from ...backend.model_manager import SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -20,12 +20,8 @@ from .baseinvocation import (
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Info to load submodel")
|
||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
@@ -55,7 +51,7 @@ class VaeField(BaseModel):
|
||||
|
||||
@invocation_output("unet_output")
|
||||
class UNetOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a UNet field"""
|
||||
"""Base class for invocations that output a UNet field."""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
|
||||
@@ -84,20 +80,13 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model key")
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
"""LoRA model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the LoRA model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="LoRA model key")
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -114,85 +103,40 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
if not context.services.model_manager.store.exists(key):
|
||||
raise Exception(f"Unknown model {key}")
|
||||
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
key=key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
key=key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
key=key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
key=key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Vae,
|
||||
key=key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -229,21 +173,16 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||
if not context.services.model_manager.store.exists(lora_key):
|
||||
raise Exception(f"Unkown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
|
||||
output = LoraLoaderOutput()
|
||||
|
||||
@@ -251,10 +190,8 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -263,10 +200,8 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -318,24 +253,19 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unknown lora name: {lora_name}!")
|
||||
if not context.services.model_manager.store.exists(lora_key):
|
||||
raise Exception(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip2')
|
||||
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip2')
|
||||
|
||||
output = SDXLLoraLoaderOutput()
|
||||
|
||||
@@ -343,10 +273,8 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -355,10 +283,8 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -367,10 +293,8 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -381,10 +305,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model's key")
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||
@@ -398,25 +319,12 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
model_name = self.vae_model.model_name
|
||||
model_type = ModelType.Vae
|
||||
key = self.vae_model.key
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unkown vae name: {model_name}!")
|
||||
return VAEOutput(
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
)
|
||||
if not context.services.model_manager.store.exists(key):
|
||||
raise Exception(f"Unkown vae: {key}!")
|
||||
|
||||
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
||||
|
||||
|
||||
@invocation_output("seamless_output")
|
||||
|
||||
@@ -8,16 +8,16 @@ from typing import List, Literal, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import ModelType, SubModelType
|
||||
from invokeai.backend.model_patcher import ONNXModelPatcher
|
||||
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util import choose_torch_device
|
||||
from ..util.ti_utils import extract_ti_triggers_from_prompt
|
||||
@@ -62,16 +62,16 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
tokenizer_info = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
||||
loras = [
|
||||
(
|
||||
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||
context.services.model_manager.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
|
||||
lora.weight,
|
||||
)
|
||||
for lora in self.clip.loras
|
||||
@@ -84,11 +84,11 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
context.services.model_manager.load_model_by_attr(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
base_model=text_encoder_info.config.base,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model,
|
||||
).model,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
@@ -257,13 +257,13 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
eta=0.0,
|
||||
)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump())
|
||||
unet_info = context.services.model_manager.load_model_by_key(**self.unet.unet.model_dump())
|
||||
|
||||
with unet_info as unet: # , ExitStack() as stack:
|
||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [
|
||||
(
|
||||
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||
context.services.model_manager.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
|
||||
lora.weight,
|
||||
)
|
||||
for lora in self.unet.loras
|
||||
@@ -344,9 +344,9 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}")
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_manager.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
)
|
||||
|
||||
@@ -400,11 +400,7 @@ class ONNXModelLoaderOutput(BaseInvocationOutput):
|
||||
class OnnxModelField(BaseModel):
|
||||
"""Onnx model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model ID")
|
||||
|
||||
|
||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
||||
@@ -416,93 +412,46 @@ class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.ONNX
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return ONNXModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae_decoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeDecoder,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.VaeDecoder,
|
||||
),
|
||||
),
|
||||
vae_encoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeEncoder,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.VaeEncoder,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -368,7 +368,7 @@ class LatentsCollectionInvocation(BaseInvocation):
|
||||
return LatentsCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> LatentsOutput:
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -44,72 +44,52 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return SDXLModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer2,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder2,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Vae,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -133,56 +113,40 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return SDXLRefinerModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer2,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder2,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Vae,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -16,14 +16,10 @@ from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESI
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.model_management.models.base import BaseModelType
|
||||
|
||||
|
||||
class T2IAdapterModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the T2I-Adapter model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model record key for the T2I-Adapter model")
|
||||
|
||||
|
||||
class T2IAdapterField(BaseModel):
|
||||
|
||||
@@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings):
|
||||
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
|
||||
|
||||
initconf: ClassVar[Optional[DictConfig]] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
argparse_groups: ClassVar[Dict[str, Any]] = {}
|
||||
|
||||
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
||||
|
||||
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
||||
def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None:
|
||||
"""Call to parse command-line arguments."""
|
||||
parser = self.get_parser()
|
||||
opt, unknown_opts = parser.parse_known_args(argv)
|
||||
@@ -68,7 +68,7 @@ class InvokeAISettings(BaseSettings):
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
def add_parser_arguments(cls, parser: ArgumentParser) -> None:
|
||||
"""Dynamically create arguments for a settings parser."""
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
@@ -117,7 +117,8 @@ class InvokeAISettings(BaseSettings):
|
||||
"""Return the category of a setting."""
|
||||
hints = get_type_hints(cls)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
result: str = get_args(hints[command_field])[0]
|
||||
return result
|
||||
else:
|
||||
return "Uncategorized"
|
||||
|
||||
@@ -158,7 +159,7 @@ class InvokeAISettings(BaseSettings):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None:
|
||||
"""Add the argparse arguments for a setting parser."""
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = (
|
||||
|
||||
@@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
||||
It also supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
def print_help(self, file=None):
|
||||
def print_help(self, file=None) -> None:
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import Field
|
||||
@@ -185,7 +185,9 @@ from .config_base import InvokeAISettings
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_MAX_VRAM = 0.5
|
||||
DEFAULT_RAM_CACHE = 10.0
|
||||
DEFAULT_VRAM_CACHE = 0.25
|
||||
DEFAULT_CONVERT_CACHE = 20.0
|
||||
|
||||
|
||||
class Categories(object):
|
||||
@@ -237,6 +239,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
|
||||
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||
@@ -260,8 +263,10 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||
|
||||
# CACHE
|
||||
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
|
||||
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
|
||||
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
|
||||
|
||||
@@ -404,6 +409,11 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""Path to the models directory."""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def models_convert_cache_path(self) -> Path:
|
||||
"""Path to the converted cache models directory."""
|
||||
return self._resolve(self.convert_cache_dir)
|
||||
|
||||
@property
|
||||
def custom_nodes_path(self) -> Path:
|
||||
"""Path to the custom nodes directory."""
|
||||
@@ -433,15 +443,20 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
return True
|
||||
|
||||
@property
|
||||
def ram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||
"""Return the ram cache size using the legacy or modern setting."""
|
||||
def ram_cache_size(self) -> float:
|
||||
"""Return the ram cache size using the legacy or modern setting (GB)."""
|
||||
return self.max_cache_size or self.ram
|
||||
|
||||
@property
|
||||
def vram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||
"""Return the vram cache size using the legacy or modern setting."""
|
||||
def vram_cache_size(self) -> float:
|
||||
"""Return the vram cache size using the legacy or modern setting (GB)."""
|
||||
return self.max_vram_cache_size or self.vram
|
||||
|
||||
@property
|
||||
def convert_cache_size(self) -> float:
|
||||
"""Return the convert cache size on disk (GB)."""
|
||||
return self.convert_cache
|
||||
|
||||
@property
|
||||
def use_cpu(self) -> bool:
|
||||
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""
|
||||
|
||||
@@ -260,3 +260,16 @@ class DownloadQueueServiceBase(ABC):
|
||||
def join(self) -> None:
|
||||
"""Wait until all jobs are off the queue."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
"""Wait until the indicated download job has reached a terminal state.
|
||||
|
||||
This will block until the indicated install job has completed,
|
||||
been cancelled, or errored out.
|
||||
|
||||
:param job: The job to wait on.
|
||||
:param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
|
||||
the job hasn't completed within the indicated time.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
@@ -48,11 +49,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
||||
"""
|
||||
self._jobs = {}
|
||||
self._jobs: Dict[int, DownloadJob] = {}
|
||||
self._next_job_id = 0
|
||||
self._queue = PriorityQueue()
|
||||
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
|
||||
self._stop_event = threading.Event()
|
||||
self._worker_pool = set()
|
||||
self._job_completed_event = threading.Event()
|
||||
self._worker_pool: Set[threading.Thread] = set()
|
||||
self._lock = threading.Lock()
|
||||
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
|
||||
self._event_bus = event_bus
|
||||
@@ -188,6 +190,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if not job.in_terminal_state:
|
||||
self.cancel_job(job)
|
||||
|
||||
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||
start = time.time()
|
||||
while not job.in_terminal_state:
|
||||
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
|
||||
self._job_completed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
return job
|
||||
|
||||
def _start_workers(self, max_workers: int) -> None:
|
||||
"""Start the requested number of worker threads."""
|
||||
self._stop_event.clear()
|
||||
@@ -223,6 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
finally:
|
||||
job.job_ended = get_iso_timestamp()
|
||||
self._job_completed_event.set() # signal a change to terminal state
|
||||
self._queue.task_done()
|
||||
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
||||
|
||||
@@ -407,11 +420,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
# Example on_progress event handler to display a TQDM status bar
|
||||
# Activate with:
|
||||
# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update
|
||||
# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update))
|
||||
class TqdmProgress(object):
|
||||
"""TQDM-based progress bar object to use in on_progress handlers."""
|
||||
|
||||
_bars: Dict[int, tqdm] # the tqdm object
|
||||
_bars: Dict[int, tqdm] # type: ignore
|
||||
_last: Dict[int, int] # last bytes downloaded
|
||||
|
||||
def __init__(self) -> None: # noqa D107
|
||||
|
||||
@@ -11,8 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
@@ -171,10 +170,7 @@ class EventServiceBase:
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
model_config: AnyModelConfig,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_queue_event(
|
||||
@@ -184,10 +180,7 @@ class EventServiceBase:
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_name": model_name,
|
||||
"base_model": base_model,
|
||||
"model_type": model_type,
|
||||
"submodel": submodel,
|
||||
"model_config": model_config.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -197,11 +190,7 @@ class EventServiceBase:
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
model_info: ModelInfo,
|
||||
model_config: AnyModelConfig,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_queue_event(
|
||||
@@ -211,13 +200,7 @@ class EventServiceBase:
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_name": model_name,
|
||||
"base_model": base_model,
|
||||
"model_type": model_type,
|
||||
"submodel": submodel,
|
||||
"hash": model_info.hash,
|
||||
"location": str(model_info.location),
|
||||
"precision": str(model_info.precision),
|
||||
"model_config": model_config.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -22,9 +22,7 @@ if TYPE_CHECKING:
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .item_storage.item_storage_base import ItemStorageABC
|
||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
||||
from .model_install import ModelInstallServiceBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .model_records import ModelRecordServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
@@ -50,9 +48,7 @@ class InvocationServices:
|
||||
latents: "LatentsStorageBase"
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
model_records: "ModelRecordServiceBase"
|
||||
download_queue: "DownloadQueueServiceBase"
|
||||
model_install: "ModelInstallServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
@@ -78,9 +74,7 @@ class InvocationServices:
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
model_records: "ModelRecordServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
model_install: "ModelInstallServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
@@ -104,9 +98,7 @@ class InvocationServices:
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.model_records = model_records
|
||||
self.download_queue = download_queue
|
||||
self.model_install = model_install
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
|
||||
@@ -29,8 +29,8 @@ writes to the system log is stored in InvocationServices.performance_statistics.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import AbstractContextManager
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
|
||||
@@ -40,18 +40,17 @@ class InvocationStatsServiceBase(ABC):
|
||||
"Abstract base class for recording node memory/time performance statistics"
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize the InvocationStatsService and reset counters to zero
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_stats(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
) -> AbstractContextManager:
|
||||
) -> Iterator[None]:
|
||||
"""
|
||||
Return a context object that will capture the statistics on the execution
|
||||
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
||||
@@ -61,7 +60,7 @@ class InvocationStatsServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_stats(self, graph_execution_state_id: str):
|
||||
def reset_stats(self, graph_execution_state_id: str) -> None:
|
||||
"""
|
||||
Reset all statistics for the indicated graph.
|
||||
:param graph_execution_state_id: The id of the session whose stats to reset.
|
||||
@@ -70,7 +69,7 @@ class InvocationStatsServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def log_stats(self, graph_execution_state_id: str):
|
||||
def log_stats(self, graph_execution_state_id: str) -> None:
|
||||
"""
|
||||
Write out the accumulated statistics to the log or somewhere else.
|
||||
:param graph_execution_state_id: The id of the session whose stats to log.
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
@@ -10,7 +11,7 @@ import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache import CacheStats
|
||||
|
||||
from .invocation_stats_base import InvocationStatsServiceBase
|
||||
from .invocation_stats_common import (
|
||||
@@ -41,7 +42,10 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self._invoker = invoker
|
||||
|
||||
@contextmanager
|
||||
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str):
|
||||
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
|
||||
# This is to handle case of the model manager not being initialized, which happens
|
||||
# during some tests.
|
||||
services = self._invoker.services
|
||||
if not self._stats.get(graph_execution_state_id):
|
||||
# First time we're seeing this graph_execution_state_id.
|
||||
self._stats[graph_execution_state_id] = GraphExecutionStats()
|
||||
@@ -55,8 +59,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
start_ram = psutil.Process().memory_info().rss
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
if self._invoker.services.model_manager:
|
||||
self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id])
|
||||
|
||||
assert services.model_manager.load is not None
|
||||
services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]
|
||||
|
||||
try:
|
||||
# Let the invocation run.
|
||||
@@ -73,7 +78,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
|
||||
|
||||
def _prune_stale_stats(self):
|
||||
def _prune_stale_stats(self) -> None:
|
||||
"""Check all graphs being tracked and prune any that have completed/errored.
|
||||
|
||||
This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
from typing import Callable, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.compel import ConditioningFieldData
|
||||
|
||||
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
@@ -20,8 +22,10 @@ class LatentsStorageBase(ABC):
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
# (LS) Added a Union with ConditioningFieldData to fix type mismatch errors in compel.py
|
||||
# Not 100% sure this isn't an existing bug.
|
||||
@abstractmethod
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.compel import ConditioningFieldData
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .latents_storage_base import LatentsStorageBase
|
||||
@@ -27,7 +28,7 @@ class DiskLatentsStorage(LatentsStorageBase):
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None:
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.compel import ConditioningFieldData
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .latents_storage_base import LatentsStorageBase
|
||||
@@ -46,7 +47,9 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
# TODO: (LS) ConditioningFieldData added as Union because of type-checking errors
|
||||
# in compel.py. Unclear whether this is a long-standing bug, but seems to run.
|
||||
def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
self._on_changed(data)
|
||||
|
||||
@@ -14,11 +14,13 @@ from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
@@ -127,8 +129,8 @@ class HFModelSource(StringLikeSource):
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
base += f":{self.variant or ''}"
|
||||
base += f":{self.subfolder}" if self.subfolder else ""
|
||||
base += f" ({self.variant})" if self.variant else ""
|
||||
return base
|
||||
|
||||
|
||||
@@ -243,7 +245,7 @@ class ModelInstallServiceBase(ABC):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: ModelMetadataStore,
|
||||
metadata_store: ModelMetadataStoreBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
@@ -324,6 +326,43 @@ class ModelInstallServiceBase(ABC):
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
:param source: String source
|
||||
:param config: Optional dict. Any fields in this dict
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record as described in `import_model()`.
|
||||
:param access_token: Optional access token for remote sources.
|
||||
|
||||
The source can be:
|
||||
1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
|
||||
2. An http or https URL (`https://foo.bar/foo`)
|
||||
3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
|
||||
|
||||
We extend the HuggingFace repo_id syntax to include the variant and the
|
||||
subfolder or path. The following are acceptable alternatives:
|
||||
stabilityai/stable-diffusion-v4
|
||||
stabilityai/stable-diffusion-v4:fp16
|
||||
stabilityai/stable-diffusion-v4:fp16:vae
|
||||
stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||
stabilityai/stable-diffusion-v4:onnx:vae
|
||||
|
||||
Because a local file path can look like a huggingface repo_id, the logic
|
||||
first checks whether the path exists on disk, and if not, it is treated as
|
||||
a parseable huggingface repo.
|
||||
|
||||
The previous support for recursing into a local folder and loading all model-like files
|
||||
has been removed.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def import_model(
|
||||
self,
|
||||
@@ -385,6 +424,18 @@ class ModelInstallServiceBase(ABC):
|
||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||
"""Cancel the indicated job."""
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
|
||||
"""Wait for the indicated job to reach a terminal state.
|
||||
|
||||
This will block until the indicated install job has completed,
|
||||
been cancelled, or errored out.
|
||||
|
||||
:param job: The job to wait on.
|
||||
:param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
|
||||
the job hasn't completed within the indicated time.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:
|
||||
"""
|
||||
@@ -394,7 +445,8 @@ class ModelInstallServiceBase(ABC):
|
||||
completed, been cancelled, or errored out.
|
||||
|
||||
:param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if
|
||||
installs do not complete within the indicated time.
|
||||
installs do not complete within the indicated time. A timeout of zero (the default)
|
||||
will block indefinitely until the installs complete.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -410,3 +462,22 @@ class ModelInstallServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def sync_to_config(self) -> None:
|
||||
"""Synchronize models on disk to those in the model record database."""
|
||||
|
||||
@abstractmethod
|
||||
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Download the model file located at source to the models cache and return its Path.
|
||||
|
||||
:param source: A Url or a string that can be converted into one.
|
||||
:param access_token: Optional access token to access restricted resources.
|
||||
|
||||
The model file will be downloaded into the system-wide model cache
|
||||
(`models/.cache`) if it isn't already there. Note that the model cache
|
||||
is periodically cleared of infrequently-used entries when the model
|
||||
converter runs.
|
||||
|
||||
Note that this doesn't automaticallly install or register the model, but is
|
||||
intended for use by nodes that need access to models that aren't directly
|
||||
supported by InvokeAI. The downloading process takes advantage of the download queue
|
||||
to avoid interrupting other operations.
|
||||
"""
|
||||
|
||||
@@ -17,10 +17,10 @@ from pydantic.networks import AnyHttpUrl
|
||||
from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@@ -33,7 +33,6 @@ from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataStore,
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
@@ -50,6 +49,7 @@ from .model_install_base import (
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
StringLikeSource,
|
||||
URLModelSource,
|
||||
)
|
||||
|
||||
@@ -64,7 +64,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: Optional[ModelMetadataStore] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
session: Optional[Session] = None,
|
||||
):
|
||||
@@ -86,19 +85,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._lock = threading.Lock()
|
||||
self._stop_event = threading.Event()
|
||||
self._downloads_changed_event = threading.Event()
|
||||
self._install_completed_event = threading.Event()
|
||||
self._download_queue = download_queue
|
||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._next_job_id = 0
|
||||
# There may not necessarily be a metadata store initialized
|
||||
# so we create one and initialize it with the same sql database
|
||||
# used by the record store service.
|
||||
if metadata_store:
|
||||
self._metadata_store = metadata_store
|
||||
else:
|
||||
assert isinstance(record_store, ModelRecordServiceSQL)
|
||||
self._metadata_store = ModelMetadataStore(record_store.db)
|
||||
self._metadata_store = record_store.metadata_store # for convenience
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@@ -145,7 +138,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if config.get("source") is None:
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
return self._register(model_path, config)
|
||||
|
||||
@@ -156,12 +149,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if config.get("source") is None:
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
old_hash = info.original_hash
|
||||
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
|
||||
old_hash = info.current_hash
|
||||
dest_path = (
|
||||
self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name)
|
||||
)
|
||||
try:
|
||||
new_path = self._copy_model(model_path, dest_path)
|
||||
except FileExistsError as excp:
|
||||
@@ -177,7 +172,40 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info,
|
||||
)
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
|
||||
subfolder=Path(match.group(3)) if match.group(3) else None,
|
||||
access_token=access_token,
|
||||
)
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
source_obj = URLModelSource(
|
||||
url=AnyHttpUrl(source),
|
||||
access_token=access_token,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
return self.import_model(source_obj, config)
|
||||
|
||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
|
||||
if similar_jobs:
|
||||
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
|
||||
return similar_jobs[0]
|
||||
|
||||
if isinstance(source, LocalModelSource):
|
||||
install_job = self._import_local_model(source, config)
|
||||
self._install_queue.put(install_job) # synchronously install
|
||||
@@ -207,14 +235,25 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
assert isinstance(jobs[0], ModelInstallJob)
|
||||
return jobs[0]
|
||||
|
||||
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
|
||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||
start = time.time()
|
||||
while not job.in_terminal_state:
|
||||
if self._install_completed_event.wait(timeout=5): # in case we miss an event
|
||||
self._install_completed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
return job
|
||||
|
||||
# TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above
|
||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
|
||||
"""Block until all installation jobs are done."""
|
||||
start = time.time()
|
||||
while len(self._download_cache) > 0:
|
||||
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
|
||||
if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
|
||||
self._downloads_changed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise Exception("Timeout exceeded")
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
self._install_queue.join()
|
||||
return self._install_jobs
|
||||
|
||||
@@ -268,6 +307,38 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def download_and_cache(
|
||||
self,
|
||||
source: Union[str, AnyHttpUrl],
|
||||
access_token: Optional[str] = None,
|
||||
timeout: int = 0,
|
||||
) -> Path:
|
||||
"""Download the model file located at source to the models cache and return its Path."""
|
||||
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
|
||||
model_path = self._app_config.models_convert_cache_path / model_hash
|
||||
|
||||
# We expect the cache directory to contain one and only one downloaded file.
|
||||
# We don't know the file's name in advance, as it is set by the download
|
||||
# content-disposition header.
|
||||
if model_path.exists():
|
||||
contents = [x for x in model_path.iterdir() if x.is_file()]
|
||||
if len(contents) > 0:
|
||||
return contents[0]
|
||||
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
job = self._download_queue.download(
|
||||
source=AnyHttpUrl(str(source)),
|
||||
dest=model_path,
|
||||
access_token=access_token,
|
||||
on_progress=TqdmProgress().update,
|
||||
)
|
||||
self._download_queue.wait_for_job(job, timeout)
|
||||
if job.complete:
|
||||
assert job.download_path is not None
|
||||
return job.download_path
|
||||
else:
|
||||
raise Exception(job.error)
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the installer threads
|
||||
# --------------------------------------------------------------------------------------------
|
||||
@@ -300,6 +371,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
@@ -330,6 +402,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# if this is an install of a remote file, then clean up the temporary directory
|
||||
if job._install_tmpdir is not None:
|
||||
rmtree(job._install_tmpdir)
|
||||
self._install_completed_event.set()
|
||||
self._install_queue.task_done()
|
||||
|
||||
self._logger.info("Install thread exiting")
|
||||
@@ -489,10 +562,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return id
|
||||
|
||||
@staticmethod
|
||||
def _guess_variant() -> ModelRepoVariant:
|
||||
def _guess_variant() -> Optional[ModelRepoVariant]:
|
||||
"""Guess the best HuggingFace variant type to download."""
|
||||
precision = choose_precision(choose_torch_device())
|
||||
return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT
|
||||
return ModelRepoVariant.FP16 if precision == "float16" else None
|
||||
|
||||
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
return ModelInstallJob(
|
||||
@@ -517,7 +590,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if not source.access_token:
|
||||
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
||||
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id)
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(
|
||||
variant=source.variant or self._guess_variant(),
|
||||
@@ -565,6 +638,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
||||
# Currently the tmpdir isn't automatically removed at exit because it is
|
||||
# being held in a daemon thread.
|
||||
if len(remote_files) == 0:
|
||||
raise ValueError(f"{source}: No downloadable files found")
|
||||
tmpdir = Path(
|
||||
mkdtemp(
|
||||
dir=self._app_config.models_path,
|
||||
@@ -580,6 +655,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
bytes=0,
|
||||
total_bytes=0,
|
||||
)
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
if hasattr(source, "subfolder") and source.subfolder:
|
||||
root = Path(remote_files[0].path.parts[0])
|
||||
subfolder = root / source.subfolder
|
||||
else:
|
||||
root = Path(".")
|
||||
subfolder = Path(".")
|
||||
|
||||
# we remember the path up to the top of the tmpdir so that it may be
|
||||
# removed safely at the end of the install process.
|
||||
install_job._install_tmpdir = tmpdir
|
||||
@@ -589,7 +674,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.debug(f"remote_files={remote_files}")
|
||||
for model_file in remote_files:
|
||||
url = model_file.url
|
||||
path = model_file.path
|
||||
path = root / model_file.path.relative_to(subfolder)
|
||||
self._logger.info(f"Downloading {url} => {path}")
|
||||
install_job.total_bytes += model_file.size
|
||||
assert hasattr(source, "access_token")
|
||||
|
||||
6
invokeai/app/services/model_load/__init__.py
Normal file
6
invokeai/app/services/model_load/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Initialization file for model load service module."""
|
||||
|
||||
from .model_load_base import ModelLoadServiceBase
|
||||
from .model_load_default import ModelLoadService
|
||||
|
||||
__all__ = ["ModelLoadServiceBase", "ModelLoadService"]
|
||||
40
invokeai/app/services/model_load/model_load_base.py
Normal file
40
invokeai/app/services/model_load/model_load_base.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Base class for model loader."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
|
||||
class ModelLoadServiceBase(ABC):
|
||||
"""Wrapper around AnyModelLoader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
106
invokeai/app/services/model_load/model_load_default.py
Normal file
106
invokeai/app/services/model_load/model_load_default.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of model loader service."""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel, ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_load_base import ModelLoadServiceBase
|
||||
|
||||
|
||||
class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Wrapper around ModelLoaderRegistry."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
|
||||
):
|
||||
"""Initialize the model load service."""
|
||||
logger = InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
logger.setLevel(app_config.log_level.upper())
|
||||
self._logger = logger
|
||||
self._app_config = app_config
|
||||
self._ram_cache = ram_cache
|
||||
self._convert_cache = convert_cache
|
||||
self._registry = registry
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@property
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model: LoadedModel = implementation(
|
||||
app_config=self._app_config,
|
||||
logger=self._logger,
|
||||
ram_cache=self._ram_cache,
|
||||
convert_cache=self._convert_cache,
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_config=model_config,
|
||||
loaded=True,
|
||||
)
|
||||
return loaded_model
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
) -> None:
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if not loaded:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
@@ -1 +1,17 @@
|
||||
from .model_manager_default import ModelManagerService # noqa F401
|
||||
"""Initialization file for model manager service."""
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
|
||||
from .model_manager_default import ModelManagerService, ModelManagerServiceBase
|
||||
|
||||
__all__ = [
|
||||
"ModelManagerServiceBase",
|
||||
"ModelManagerService",
|
||||
"AnyModel",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
"SubModelType",
|
||||
"LoadedModel",
|
||||
]
|
||||
|
||||
@@ -1,286 +1,67 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import (
|
||||
AddModelResult,
|
||||
BaseModelType,
|
||||
MergeInterpolationMethod,
|
||||
ModelInfo,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
"""Abstract base class for the model manager service."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
pass
|
||||
# attributes:
|
||||
# store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
|
||||
# install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
|
||||
# load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model with name and type.
|
||||
submodel can be used to get a part (such as the vae)
|
||||
of a diffusers pipeline."""
|
||||
def build_model_manager(
|
||||
cls,
|
||||
app_config: InvokeAIAppConfig,
|
||||
db: SqliteDatabase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
|
||||
Use it rather than the __init__ constructor. This class
|
||||
method simplifies the construction considerably.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self):
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
"""Return the ModelRecordServiceBase used to store and retrieve configuration records."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def load(self) -> ModelLoadServiceBase:
|
||||
"""Return the ModelLoadServiceBase used to load models from their configuration records."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def install(self) -> ModelInstallServiceBase:
|
||||
"""Return the ModelInstallServiceBase used to download and manipulate model files."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
Uses the exact format as the omegaconf stanza.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_type1:
|
||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||
'model_name' : name,
|
||||
'model_type' : SDModelType,
|
||||
'description': description,
|
||||
'format': 'folder'|'safetensors'|'ckpt'
|
||||
},
|
||||
model_name2: { etc }
|
||||
},
|
||||
model_type2:
|
||||
{ model_name_n: etc
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@@ -1,413 +1,149 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.backend.model_management import (
|
||||
AddModelResult,
|
||||
BaseModelType,
|
||||
MergeInterpolationMethod,
|
||||
ModelInfo,
|
||||
ModelManager,
|
||||
ModelMerger,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.backend.model_management.model_search import FindModels
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase, UnknownModelException
|
||||
from .model_manager_base import ModelManagerServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
|
||||
|
||||
# simple implementation
|
||||
class ModelManagerService(ModelManagerServiceBase):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
"""
|
||||
The ModelManagerService handles various aspects of model installation, maintenance and loading.
|
||||
|
||||
It bundles three distinct services:
|
||||
model_manager.store -- Routines to manage the database of model configuration records.
|
||||
model_manager.install -- Routines to install, move and delete models.
|
||||
model_manager.load -- Routines to load models into memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
store: ModelRecordServiceBase,
|
||||
install: ModelInstallServiceBase,
|
||||
load: ModelLoadServiceBase,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
if config.model_conf_path and config.model_conf_path.exists():
|
||||
config_file = config.model_conf_path
|
||||
else:
|
||||
config_file = config.root_dir / "configs/models.yaml"
|
||||
self._store = store
|
||||
self._install = install
|
||||
self._load = load
|
||||
|
||||
logger.debug(f"Config file={config_file}")
|
||||
@property
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
return self._store
|
||||
|
||||
device = torch.device(choose_torch_device())
|
||||
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
|
||||
logger.info(f"GPU device = {device} {device_name}")
|
||||
@property
|
||||
def install(self) -> ModelInstallServiceBase:
|
||||
return self._install
|
||||
|
||||
precision = config.precision
|
||||
if precision == "auto":
|
||||
precision = choose_precision(device)
|
||||
dtype = torch.float32 if precision == "float32" else torch.float16
|
||||
@property
|
||||
def load(self) -> ModelLoadServiceBase:
|
||||
return self._load
|
||||
|
||||
# this is transitional backward compatibility
|
||||
# support for the deprecated `max_loaded_models`
|
||||
# configuration value. If present, then the
|
||||
# cache size is set to 2.5 GB times
|
||||
# the number of max_loaded_models. Otherwise
|
||||
# use new `ram_cache_size` config setting
|
||||
max_cache_size = config.ram_cache_size
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
for service in [self._store, self._install, self._load]:
|
||||
if hasattr(service, "start"):
|
||||
service.start(invoker)
|
||||
|
||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
for service in [self._store, self._install, self._load]:
|
||||
if hasattr(service, "stop"):
|
||||
service.stop(invoker)
|
||||
|
||||
sequential_offload = config.sequential_guidance
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
return self.load.load_model(model_config, submodel_type, context)
|
||||
|
||||
self.mgr = ModelManager(
|
||||
config=config_file,
|
||||
device_type=device,
|
||||
precision=dtype,
|
||||
max_cache_size=max_cache_size,
|
||||
sequential_offload=sequential_offload,
|
||||
logger=logger,
|
||||
)
|
||||
logger.info("Model manager service initialized")
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
config = self.store.get_model(key)
|
||||
return self.load.load_model(config, submodel_type, context)
|
||||
|
||||
def get_model(
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode.
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||
|
||||
# we can emit model loading events if we are executing with access to the invocation context
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
This is provided for API compatability with the get_model() method
|
||||
in the original model manager. However, note that LoadedModel is
|
||||
not the same as the original ModelInfo that ws returned.
|
||||
|
||||
model_info = self.mgr.get_model(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
submodel,
|
||||
)
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: The invocation context.
|
||||
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
|
||||
return model_info
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
Exceptions: UnknownModelException -- model with this key not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
ValueError -- more than one model matches this combination
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
"""
|
||||
return self.mgr.model_exists(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
return self.mgr.model_info(model_name, base_model, model_type)
|
||||
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
return self.mgr.model_names()
|
||||
|
||||
def list_models(
|
||||
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f"add/update model {model_name}")
|
||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException exception if the name does not already exist.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f"update model {model_name}")
|
||||
if not self.model_exists(model_name, base_model, model_type):
|
||||
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well.
|
||||
"""
|
||||
self.logger.debug(f"delete model {model_name}")
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
self.mgr.commit()
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
convert_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
self.logger.debug(f"convert model {model_name}")
|
||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
self.mgr.cache.stats = cache_stats
|
||||
|
||||
def commit(self, conf_file: Optional[Path] = None):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
return self.mgr.commit(conf_file)
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
configs = self.store.search_by_attr(model_name, base_model, model_type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
||||
elif len(configs) > 1:
|
||||
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
return self.load.load_model(configs[0], submodel, context)
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
return self.mgr.logger
|
||||
@classmethod
|
||||
def build_model_manager(
|
||||
cls,
|
||||
app_config: InvokeAIAppConfig,
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
For simplicity, use this class method rather than the __init__ constructor.
|
||||
"""
|
||||
logger = InvokeAILogger.get_logger(cls.__name__)
|
||||
logger.setLevel(app_config.log_level.upper())
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
"""
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: float = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
merger = ModelMerger(self.mgr)
|
||||
try:
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_names=model_names,
|
||||
base_model=base_model,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
search = FindModels([directory], self.logger)
|
||||
return search.list_models()
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
return self.mgr.sync_to_config()
|
||||
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
config = self.mgr.app_config
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: Optional[str] = None,
|
||||
new_base: Optional[BaseModelType] = None,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model. Can provide a new name and/or a new base.
|
||||
:param model_name: Current name of the model
|
||||
:param base_model: Current base of the model
|
||||
:param model_type: Model type (can't be changed)
|
||||
:param new_name: New name for the model
|
||||
:param new_base: New base for the model
|
||||
"""
|
||||
self.mgr.rename_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=new_name,
|
||||
new_base=new_base,
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger
|
||||
)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
|
||||
)
|
||||
loader = ModelLoadService(
|
||||
app_config=app_config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
registry=ModelLoaderRegistry,
|
||||
)
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=model_record_service,
|
||||
download_queue=download_queue,
|
||||
event_bus=events,
|
||||
)
|
||||
return cls(store=model_record_service, install=installer, load=loader)
|
||||
|
||||
9
invokeai/app/services/model_metadata/__init__.py
Normal file
9
invokeai/app/services/model_metadata/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Init file for ModelMetadataStoreService module."""
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
from .metadata_store_sql import ModelMetadataStoreSQL
|
||||
|
||||
__all__ = [
|
||||
"ModelMetadataStoreBase",
|
||||
"ModelMetadataStoreSQL",
|
||||
]
|
||||
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
@abstractmethod
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
@@ -11,8 +11,15 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@@ -104,7 +111,7 @@ class ModelRecordServiceBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata_store(self) -> ModelMetadataStore:
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
pass
|
||||
|
||||
@@ -146,7 +153,7 @@ class ModelRecordServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
Return True if a model with the indicated key exists in the database.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
|
||||
@@ -54,8 +54,9 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
DuplicateModelException,
|
||||
@@ -69,16 +70,16 @@ from .model_records_base import (
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
:param db: Sqlite connection object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
self._cursor = db.conn.cursor()
|
||||
self._metadata_store = metadata_store
|
||||
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
@@ -158,7 +159,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
@@ -199,7 +200,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@@ -207,7 +208,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]))
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
@@ -265,12 +266,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
select config FROM model_config
|
||||
select config, strftime('%s',updated_at) FROM model_config
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
@@ -279,12 +282,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
@@ -293,18 +298,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE original_hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
return results
|
||||
|
||||
@property
|
||||
def metadata_store(self) -> ModelMetadataStore:
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
return ModelMetadataStore(self._db)
|
||||
return self._metadata_store
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
@@ -325,18 +332,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
store = ModelMetadataStore(self._db)
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
keys = store.search_by_tag(tags)
|
||||
return [self.get_model(x) for x in keys]
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
store = ModelMetadataStore(self._db)
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_tags()
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
store = ModelMetadataStore(self._db)
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_all_metadata()
|
||||
|
||||
def list_models(
|
||||
|
||||
@@ -8,6 +8,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -33,6 +34,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_4())
|
||||
migrator.register_migration(build_migration_5())
|
||||
migrator.register_migration(build_migration_6())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration6Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._recreate_model_triggers(cursor)
|
||||
self._delete_ip_adapters(cursor)
|
||||
|
||||
def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Adds the timestamp trigger to the model_config table.
|
||||
|
||||
This trigger was inadvertently dropped in earlier migration scripts.
|
||||
"""
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
|
||||
AFTER UPDATE
|
||||
ON model_config FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Delete all the IP adapters.
|
||||
|
||||
The model manager will automatically find and re-add them after the migration
|
||||
is done. This allows the manager to add the correct image encoder to their
|
||||
configuration records.
|
||||
"""
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
WHERE type='ip_adapter';
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def build_migration_6() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 5 to 6.
|
||||
|
||||
This migration does the following:
|
||||
- Adds the model_config_updated_at trigger if it does not exist
|
||||
- Delete all ip_adapter models so that the model prober can find and
|
||||
update with the correct image processor model.
|
||||
"""
|
||||
migration_6 = Migration(
|
||||
from_version=5,
|
||||
to_version=6,
|
||||
callback=Migration6Callback(),
|
||||
)
|
||||
|
||||
return migration_6
|
||||
@@ -5,7 +5,7 @@ import uuid
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
def get_timestamp() -> int:
|
||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
|
||||
|
||||
@@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
||||
SEED_MAX = np.iinfo(np.uint32).max
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
def get_random_seed() -> int:
|
||||
rng = np.random.default_rng(seed=None)
|
||||
return int(rng.integers(0, SEED_MAX))
|
||||
|
||||
|
||||
def uuid_string():
|
||||
def uuid_string() -> str:
|
||||
res = uuid.uuid4()
|
||||
return str(res)
|
||||
|
||||
|
||||
def is_optional(value: typing.Any):
|
||||
def is_optional(value: typing.Any) -> bool:
|
||||
"""Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None]."""
|
||||
return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value)
|
||||
|
||||
@@ -3,7 +3,7 @@ from PIL import Image
|
||||
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
|
||||
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_manager import BaseModelType
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401
|
||||
from .model_management.models import SilenceWarnings # noqa: F401
|
||||
|
||||
4
invokeai/backend/embeddings/__init__.py
Normal file
4
invokeai/backend/embeddings/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""Initialization file for invokeai.backend.embeddings modules."""
|
||||
|
||||
# from .model_patcher import ModelPatcher
|
||||
# __all__ = ["ModelPatcher"]
|
||||
12
invokeai/backend/embeddings/embedding_base.py
Normal file
12
invokeai/backend/embeddings/embedding_base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Base class for LoRA and Textual Inversion models.
|
||||
|
||||
The EmbeddingRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw,
|
||||
and is used for type checking of calls to the model patcher.
|
||||
|
||||
The use of "Raw" here is a historical artifact, and carried forward in
|
||||
order to avoid confusion.
|
||||
"""
|
||||
|
||||
|
||||
class EmbeddingModelRaw:
|
||||
"""Base class for LoRA and Textual Inversion models."""
|
||||
@@ -8,8 +8,8 @@ from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend import SilenceWarnings
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
|
||||
@@ -25,18 +25,20 @@ from invokeai.app.services.model_install import (
|
||||
ModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# name of the starter models file
|
||||
INITIAL_MODELS = "INITIAL_MODELS2.yaml"
|
||||
INITIAL_MODELS = "INITIAL_MODELS.yaml"
|
||||
|
||||
|
||||
def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||
@@ -44,7 +46,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
||||
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
return obj
|
||||
|
||||
|
||||
@@ -53,12 +55,10 @@ def initialize_installer(
|
||||
) -> ModelInstallServiceBase:
|
||||
"""Return an initialized ModelInstallService object."""
|
||||
record_store = initialize_record_store(app_config)
|
||||
metadata_store = record_store.metadata_store
|
||||
download_queue = DownloadQueueService()
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=record_store,
|
||||
metadata_store=metadata_store,
|
||||
download_queue=download_queue,
|
||||
event_bus=event_bus,
|
||||
)
|
||||
@@ -98,11 +98,13 @@ class TqdmEventService(EventServiceBase):
|
||||
super().__init__()
|
||||
self._bars: Dict[str, tqdm] = {}
|
||||
self._last: Dict[str, int] = {}
|
||||
self._logger = InvokeAILogger.get_logger(__name__)
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
data = payload["data"]
|
||||
source = data["source"]
|
||||
if payload["event"] == "model_install_downloading":
|
||||
data = payload["data"]
|
||||
dest = data["local_path"]
|
||||
total_bytes = data["total_bytes"]
|
||||
bytes = data["bytes"]
|
||||
@@ -111,6 +113,12 @@ class TqdmEventService(EventServiceBase):
|
||||
self._last[dest] = 0
|
||||
self._bars[dest].update(bytes - self._last[dest])
|
||||
self._last[dest] = bytes
|
||||
elif payload["event"] == "model_install_completed":
|
||||
self._logger.info(f"{source}: installed successfully.")
|
||||
elif payload["event"] == "model_install_error":
|
||||
self._logger.warning(f"{source}: installation failed with error {data['error']}")
|
||||
elif payload["event"] == "model_install_cancelled":
|
||||
self._logger.warning(f"{source}: installation cancelled")
|
||||
|
||||
|
||||
class InstallHelper(object):
|
||||
@@ -225,11 +233,19 @@ class InstallHelper(object):
|
||||
|
||||
if model_path.exists(): # local file on disk
|
||||
return LocalModelSource(path=model_path.absolute(), inplace=True)
|
||||
if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url): # hugging face repo_id
|
||||
|
||||
# parsing huggingface repo ids
|
||||
# we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16"
|
||||
variants = "|".join([x.lower() for x in ModelRepoVariant.__members__])
|
||||
if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
|
||||
repo_id = match.group(1)
|
||||
repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
|
||||
subfolder = Path(model_info.subfolder) if model_info.subfolder else None
|
||||
return HFModelSource(
|
||||
repo_id=model_path_id_or_url,
|
||||
repo_id=repo_id,
|
||||
access_token=HfFolder.get_token(),
|
||||
subfolder=model_info.subfolder,
|
||||
subfolder=subfolder,
|
||||
variant=repo_variant,
|
||||
)
|
||||
if re.match(r"^(http|https):", model_path_id_or_url):
|
||||
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
|
||||
@@ -270,12 +286,14 @@ class InstallHelper(object):
|
||||
model_name=model_name,
|
||||
)
|
||||
if len(matches) > 1:
|
||||
print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.")
|
||||
self._logger.error(
|
||||
"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate"
|
||||
)
|
||||
elif not matches:
|
||||
print(f"{model}: unknown model")
|
||||
self._logger.error(f"{model_to_remove}: unknown model")
|
||||
else:
|
||||
for m in matches:
|
||||
print(f"Deleting {m.type}:{m.name}")
|
||||
self._logger.info(f"Deleting {m.type}:{m.name}")
|
||||
installer.delete(m.key)
|
||||
|
||||
installer.wait_for_installs()
|
||||
|
||||
@@ -18,31 +18,30 @@ from argparse import Namespace
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import Any, get_args, get_type_hints
|
||||
from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints
|
||||
from urllib import request
|
||||
|
||||
import npyscreen
|
||||
import omegaconf
|
||||
import psutil
|
||||
import torch
|
||||
import transformers
|
||||
import yaml
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers import AutoencoderKL, ModelMixin
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from huggingface_hub import HfFolder
|
||||
from huggingface_hub import login as hf_hub_login
|
||||
from omegaconf import OmegaConf
|
||||
from pydantic import ValidationError
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic.error_wrappers import ValidationError
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained
|
||||
from invokeai.backend.model_management.model_probe import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.model_install import addModelsForm
|
||||
|
||||
# TO DO - Move all the frontend code into invokeai.frontend.install
|
||||
from invokeai.frontend.install.widgets import (
|
||||
@@ -61,7 +60,7 @@ warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
def get_literal_fields(field) -> list[Any]:
|
||||
def get_literal_fields(field: str) -> Tuple[Any]:
|
||||
return get_args(get_type_hints(InvokeAIAppConfig).get(field))
|
||||
|
||||
|
||||
@@ -80,8 +79,7 @@ ATTENTION_SLICE_CHOICES = get_literal_fields("attention_slice_size")
|
||||
GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"]
|
||||
GB = 1073741824 # GB in bytes
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
|
||||
|
||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0.0, 0.0)
|
||||
|
||||
MAX_VRAM /= GB
|
||||
MAX_RAM = psutil.virtual_memory().total / GB
|
||||
@@ -96,13 +94,15 @@ logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class DummyWidgetValue(Enum):
|
||||
"""Dummy widget values."""
|
||||
|
||||
zero = 0
|
||||
true = True
|
||||
false = False
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: None):
|
||||
def postscript(errors: Set[str]) -> None:
|
||||
if not any(errors):
|
||||
message = f"""
|
||||
** INVOKEAI INSTALLATION SUCCESSFUL **
|
||||
@@ -143,7 +143,7 @@ def yes_or_no(prompt: str, default_yes=True):
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def HfLogin(access_token) -> str:
|
||||
def HfLogin(access_token) -> None:
|
||||
"""
|
||||
Helper for logging in to Huggingface
|
||||
The stdout capture is needed to hide the irrelevant "git credential helper" warning
|
||||
@@ -162,7 +162,7 @@ def HfLogin(access_token) -> str:
|
||||
|
||||
# -------------------------------------
|
||||
class ProgressBar:
|
||||
def __init__(self, model_name="file"):
|
||||
def __init__(self, model_name: str = "file"):
|
||||
self.pbar = None
|
||||
self.name = model_name
|
||||
|
||||
@@ -179,6 +179,22 @@ class ProgressBar:
|
||||
self.pbar.update(block_size)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_from_pretrained(model_class: Type[ModelMixin], model_name: str, destination: Path, **kwargs: Any):
|
||||
filter = lambda x: "fp16 is not a valid" not in x.getMessage() # noqa E731
|
||||
logger.addFilter(filter)
|
||||
try:
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
finally:
|
||||
logger.removeFilter(filter)
|
||||
return destination
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
|
||||
try:
|
||||
@@ -249,6 +265,7 @@ def download_conversion_models():
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
# TO DO: use the download queue here.
|
||||
def download_realesrgan():
|
||||
logger.info("Installing ESRGAN Upscaling models...")
|
||||
URLs = [
|
||||
@@ -288,18 +305,19 @@ def download_lama():
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_support_models():
|
||||
def download_support_models() -> None:
|
||||
download_realesrgan()
|
||||
download_lama()
|
||||
download_conversion_models()
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
def get_root(root: Optional[str] = None) -> str:
|
||||
if root:
|
||||
return root
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
elif root := os.environ.get("INVOKEAI_ROOT"):
|
||||
assert root is not None
|
||||
return root
|
||||
else:
|
||||
return str(config.root_path)
|
||||
|
||||
@@ -455,6 +473,25 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
max_width=110,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers..",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.disk = self.add_widget_intelligent(
|
||||
npyscreen.Slider,
|
||||
value=clip(old_opts.convert_cache, range=(0, 100), step=0.5),
|
||||
out_of=100,
|
||||
lowest=0.0,
|
||||
step=0.5,
|
||||
relx=8,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
|
||||
@@ -495,6 +532,14 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
)
|
||||
else:
|
||||
self.vram = DummyWidgetValue.zero
|
||||
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Location of the database used to store model path and configuration information:",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
@@ -506,19 +551,21 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=40,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.autoimport_dirs = {}
|
||||
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir),
|
||||
name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "",
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
@@ -555,6 +602,10 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
self.attention_slice_label.hidden = not show
|
||||
self.attention_slice_size.hidden = not show
|
||||
|
||||
def show_hide_model_conf_override(self, value):
|
||||
self.model_conf_override.hidden = value
|
||||
self.model_conf_override.display()
|
||||
|
||||
def on_ok(self):
|
||||
options = self.marshall_arguments()
|
||||
if self.validate_field_values(options):
|
||||
@@ -584,18 +635,21 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
else:
|
||||
return True
|
||||
|
||||
def marshall_arguments(self):
|
||||
def marshall_arguments(self) -> Namespace:
|
||||
new_opts = Namespace()
|
||||
|
||||
for attr in [
|
||||
"ram",
|
||||
"vram",
|
||||
"convert_cache",
|
||||
"outdir",
|
||||
]:
|
||||
if hasattr(self, attr):
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
for attr in self.autoimport_dirs:
|
||||
if not self.autoimport_dirs[attr].value:
|
||||
continue
|
||||
directory = Path(self.autoimport_dirs[attr].value)
|
||||
if directory.is_relative_to(config.root_path):
|
||||
directory = directory.relative_to(config.root_path)
|
||||
@@ -615,13 +669,14 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
|
||||
|
||||
class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, program_opts: Namespace, invokeai_opts: Namespace):
|
||||
def __init__(self, program_opts: Namespace, invokeai_opts: InvokeAIAppConfig, install_helper: InstallHelper):
|
||||
super().__init__()
|
||||
self.program_opts = program_opts
|
||||
self.invokeai_opts = invokeai_opts
|
||||
self.user_cancelled = False
|
||||
self.autoload_pending = True
|
||||
self.install_selections = default_user_selections(program_opts)
|
||||
self.install_helper = install_helper
|
||||
self.install_selections = default_user_selections(program_opts, install_helper)
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
@@ -640,16 +695,10 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
cycle_widgets=False,
|
||||
)
|
||||
|
||||
def new_opts(self):
|
||||
def new_opts(self) -> Namespace:
|
||||
return self.options.marshall_arguments()
|
||||
|
||||
|
||||
def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace:
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp.run()
|
||||
return editApp.new_opts()
|
||||
|
||||
|
||||
def default_ramcache() -> float:
|
||||
"""Run a heuristic for the default RAM cache based on installed RAM."""
|
||||
|
||||
@@ -660,27 +709,18 @@ def default_ramcache() -> float:
|
||||
) # 2.1 is just large enough for sd 1.5 ;-)
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
def default_startup_options(init_file: Path) -> InvokeAIAppConfig:
|
||||
opts = InvokeAIAppConfig.get_config()
|
||||
opts.ram = opts.ram or default_ramcache()
|
||||
opts.ram = default_ramcache()
|
||||
return opts
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||
try:
|
||||
installer = ModelInstall(config)
|
||||
except omegaconf.errors.ConfigKeyError:
|
||||
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
|
||||
initialize_rootdir(config.root_path, True)
|
||||
installer = ModelInstall(config)
|
||||
|
||||
models = installer.all_models()
|
||||
def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections:
|
||||
default_model = install_helper.default_model()
|
||||
assert default_model is not None
|
||||
default_models = [default_model] if program_opts.default_only else install_helper.recommended_models()
|
||||
return InstallSelections(
|
||||
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
|
||||
if program_opts.default_only
|
||||
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
||||
if program_opts.yes_to_all
|
||||
else [],
|
||||
install_models=default_models if program_opts.yes_to_all else [],
|
||||
)
|
||||
|
||||
|
||||
@@ -716,21 +756,10 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def maybe_create_models_yaml(root: Path):
|
||||
models_yaml = root / "configs" / "models.yaml"
|
||||
if models_yaml.exists():
|
||||
if OmegaConf.load(models_yaml).get("__metadata__"): # up to date
|
||||
return
|
||||
else:
|
||||
logger.info("Creating new models.yaml, original saved as models.yaml.orig")
|
||||
models_yaml.rename(models_yaml.parent / "models.yaml.orig")
|
||||
|
||||
with open(models_yaml, "w") as yaml_file:
|
||||
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
|
||||
def run_console_ui(
|
||||
program_opts: Namespace, initfile: Path, install_helper: InstallHelper
|
||||
) -> Tuple[Optional[Namespace], Optional[InstallSelections]]:
|
||||
invokeai_opts = default_startup_options(initfile)
|
||||
invokeai_opts.root = program_opts.root
|
||||
|
||||
@@ -739,22 +768,16 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
# the install-models application spawns a subprocess to install
|
||||
# models, and will crash unless this is set before running.
|
||||
import torch
|
||||
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts, install_helper)
|
||||
editApp.run()
|
||||
if editApp.user_cancelled:
|
||||
return (None, None)
|
||||
else:
|
||||
return (editApp.new_opts, editApp.install_selections)
|
||||
return (editApp.new_opts(), editApp.install_selections)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def write_opts(opts: Namespace, init_file: Path):
|
||||
def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None:
|
||||
"""
|
||||
Update the invokeai.yaml file with values from current settings.
|
||||
"""
|
||||
@@ -762,7 +785,7 @@ def write_opts(opts: Namespace, init_file: Path):
|
||||
new_config = InvokeAIAppConfig.get_config()
|
||||
new_config.root = config.root
|
||||
|
||||
for key, value in opts.__dict__.items():
|
||||
for key, value in opts.model_dump().items():
|
||||
if hasattr(new_config, key):
|
||||
setattr(new_config, key, value)
|
||||
|
||||
@@ -779,7 +802,7 @@ def default_output_dir() -> Path:
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
def write_default_options(program_opts: Namespace, initfile: Path) -> None:
|
||||
opt = default_startup_options(initfile)
|
||||
write_opts(opt, initfile)
|
||||
|
||||
@@ -789,16 +812,11 @@ def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
# the legacy Args object in order to parse
|
||||
# the old init file and write out the new
|
||||
# yaml format.
|
||||
def migrate_init_file(legacy_format: Path):
|
||||
def migrate_init_file(legacy_format: Path) -> None:
|
||||
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
|
||||
new = InvokeAIAppConfig.get_config()
|
||||
|
||||
fields = [
|
||||
x
|
||||
for x, y in InvokeAIAppConfig.model_fields.items()
|
||||
if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED"
|
||||
]
|
||||
for attr in fields:
|
||||
for attr in InvokeAIAppConfig.model_fields.keys():
|
||||
if hasattr(old, attr):
|
||||
try:
|
||||
setattr(new, attr, getattr(old, attr))
|
||||
@@ -819,7 +837,7 @@ def migrate_init_file(legacy_format: Path):
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def migrate_models(root: Path):
|
||||
def migrate_models(root: Path) -> None:
|
||||
from invokeai.backend.install.migrate_to_3 import do_migrate
|
||||
|
||||
do_migrate(root, root)
|
||||
@@ -838,7 +856,9 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
|
||||
):
|
||||
logger.info("** Migrating invokeai.init to invokeai.yaml")
|
||||
migrate_init_file(old_init_file)
|
||||
config.parse_args(argv=[], conf=OmegaConf.load(new_init_file))
|
||||
omegaconf = OmegaConf.load(new_init_file)
|
||||
assert isinstance(omegaconf, DictConfig)
|
||||
config.parse_args(argv=[], conf=omegaconf)
|
||||
|
||||
if old_hub.exists():
|
||||
migrate_models(config.root_path)
|
||||
@@ -849,7 +869,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main() -> None:
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||
parser.add_argument(
|
||||
"--skip-sd-weights",
|
||||
@@ -908,6 +928,7 @@ def main() -> None:
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(["--precision", "float32"])
|
||||
config.parse_args(invoke_args)
|
||||
config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
logger = InvokeAILogger().get_logger(config=config)
|
||||
|
||||
errors = set()
|
||||
@@ -921,14 +942,18 @@ def main() -> None:
|
||||
# run this unconditionally in case new directories need to be added
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
|
||||
models_to_download = default_user_selections(opt)
|
||||
# this will initialize the models.yaml file if not present
|
||||
install_helper = InstallHelper(config, logger)
|
||||
|
||||
models_to_download = default_user_selections(opt, install_helper)
|
||||
new_init_file = config.root_path / "invokeai.yaml"
|
||||
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, new_init_file)
|
||||
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
|
||||
|
||||
else:
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper)
|
||||
if init_options:
|
||||
write_opts(init_options, new_init_file)
|
||||
else:
|
||||
@@ -943,10 +968,12 @@ def main() -> None:
|
||||
|
||||
if opt.skip_sd_weights:
|
||||
logger.warning("Skipping diffusion weights download per user request")
|
||||
|
||||
elif models_to_download:
|
||||
process_and_execute(opt, models_to_download)
|
||||
install_helper.add_or_delete(models_to_download)
|
||||
|
||||
postscript(errors=errors)
|
||||
|
||||
if not opt.yes_to_all:
|
||||
input("Press any key to continue...")
|
||||
except WindowTooSmallException as e:
|
||||
|
||||
@@ -1,591 +0,0 @@
|
||||
"""
|
||||
Migrate the models directory and models.yaml file from an existing
|
||||
InvokeAI 2.3 installation to 3.0.0.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import diffusers
|
||||
import transformers
|
||||
import yaml
|
||||
from diffusers import AutoencoderKL, StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager
|
||||
from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
diffusers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# holder for paths that we will migrate
|
||||
@dataclass
|
||||
class ModelPaths:
|
||||
models: Path
|
||||
embeddings: Path
|
||||
loras: Path
|
||||
controlnets: Path
|
||||
|
||||
|
||||
class MigrateTo3(object):
|
||||
def __init__(
|
||||
self,
|
||||
from_root: Path,
|
||||
to_models: Path,
|
||||
model_manager: ModelManager,
|
||||
src_paths: ModelPaths,
|
||||
):
|
||||
self.root_directory = from_root
|
||||
self.dest_models = to_models
|
||||
self.mgr = model_manager
|
||||
self.src_paths = src_paths
|
||||
|
||||
@classmethod
|
||||
def initialize_yaml(cls, yaml_file: Path):
|
||||
with open(yaml_file, "w") as file:
|
||||
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||
|
||||
def create_directory_structure(self):
|
||||
"""
|
||||
Create the basic directory structure for the models folder.
|
||||
"""
|
||||
for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
for model_type in [
|
||||
ModelType.Main,
|
||||
ModelType.Vae,
|
||||
ModelType.Lora,
|
||||
ModelType.ControlNet,
|
||||
ModelType.TextualInversion,
|
||||
]:
|
||||
path = self.dest_models / model_base.value / model_type.value
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
path = self.dest_models / "core"
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def copy_file(src: Path, dest: Path):
|
||||
"""
|
||||
copy a single file with logging
|
||||
"""
|
||||
if dest.exists():
|
||||
logger.info(f"Skipping existing {str(dest)}")
|
||||
return
|
||||
logger.info(f"Copying {str(src)} to {str(dest)}")
|
||||
try:
|
||||
shutil.copy(src, dest)
|
||||
except Exception as e:
|
||||
logger.error(f"COPY FAILED: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def copy_dir(src: Path, dest: Path):
|
||||
"""
|
||||
Recursively copy a directory with logging
|
||||
"""
|
||||
if dest.exists():
|
||||
logger.info(f"Skipping existing {str(dest)}")
|
||||
return
|
||||
|
||||
logger.info(f"Copying {str(src)} to {str(dest)}")
|
||||
try:
|
||||
shutil.copytree(src, dest)
|
||||
except Exception as e:
|
||||
logger.error(f"COPY FAILED: {str(e)}")
|
||||
|
||||
def migrate_models(self, src_dir: Path):
|
||||
"""
|
||||
Recursively walk through src directory, probe anything
|
||||
that looks like a model, and copy the model into the
|
||||
appropriate location within the destination models directory.
|
||||
"""
|
||||
directories_scanned = set()
|
||||
for root, dirs, files in os.walk(src_dir, followlinks=True):
|
||||
for d in dirs:
|
||||
try:
|
||||
model = Path(root, d)
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / model.name
|
||||
self.copy_dir(model, dest)
|
||||
directories_scanned.add(model)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
for f in files:
|
||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
||||
# let them be copied as part of a tree copy operation
|
||||
try:
|
||||
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
|
||||
continue
|
||||
model = Path(root, f)
|
||||
if model.parent in directories_scanned:
|
||||
continue
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / f
|
||||
self.copy_file(model, dest)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
|
||||
def migrate_support_models(self):
|
||||
"""
|
||||
Copy the clipseg, upscaler, and restoration models to their new
|
||||
locations.
|
||||
"""
|
||||
dest_directory = self.dest_models
|
||||
if (self.root_directory / "models/clipseg").exists():
|
||||
self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg")
|
||||
if (self.root_directory / "models/realesrgan").exists():
|
||||
self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan")
|
||||
for d in ["codeformer", "gfpgan"]:
|
||||
path = self.root_directory / "models" / d
|
||||
if path.exists():
|
||||
self.copy_dir(path, dest_directory / f"core/face_restoration/{d}")
|
||||
|
||||
def migrate_tuning_models(self):
|
||||
"""
|
||||
Migrate the embeddings, loras and controlnets directories to their new homes.
|
||||
"""
|
||||
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
|
||||
if not src:
|
||||
continue
|
||||
if src.is_dir():
|
||||
logger.info(f"Scanning {src}")
|
||||
self.migrate_models(src)
|
||||
else:
|
||||
logger.info(f"{src} directory not found; skipping")
|
||||
continue
|
||||
|
||||
def migrate_conversion_models(self):
|
||||
"""
|
||||
Migrate all the models that are needed by the ckpt_to_diffusers conversion
|
||||
script.
|
||||
"""
|
||||
|
||||
dest_directory = self.dest_models
|
||||
kwargs = {
|
||||
"cache_dir": self.root_directory / "models/hub",
|
||||
# local_files_only = True
|
||||
}
|
||||
try:
|
||||
logger.info("Migrating core tokenizers and text encoders")
|
||||
target_dir = dest_directory / "core" / "convert"
|
||||
|
||||
self._migrate_pretrained(
|
||||
BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs
|
||||
)
|
||||
|
||||
# sd-1
|
||||
repo_id = "openai/clip-vit-large-patch14"
|
||||
self._migrate_pretrained(
|
||||
CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs
|
||||
)
|
||||
self._migrate_pretrained(
|
||||
CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs
|
||||
)
|
||||
|
||||
# sd-2
|
||||
repo_id = "stabilityai/stable-diffusion-2"
|
||||
self._migrate_pretrained(
|
||||
CLIPTokenizer,
|
||||
repo_id=repo_id,
|
||||
dest=target_dir / "stable-diffusion-2-clip" / "tokenizer",
|
||||
**{"subfolder": "tokenizer", **kwargs},
|
||||
)
|
||||
self._migrate_pretrained(
|
||||
CLIPTextModel,
|
||||
repo_id=repo_id,
|
||||
dest=target_dir / "stable-diffusion-2-clip" / "text_encoder",
|
||||
**{"subfolder": "text_encoder", **kwargs},
|
||||
)
|
||||
|
||||
# VAE
|
||||
logger.info("Migrating stable diffusion VAE")
|
||||
self._migrate_pretrained(
|
||||
AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs
|
||||
)
|
||||
|
||||
# safety checking
|
||||
logger.info("Migrating safety checker")
|
||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||
self._migrate_pretrained(
|
||||
AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs
|
||||
)
|
||||
self._migrate_pretrained(
|
||||
StableDiffusionSafetyChecker,
|
||||
repo_id=repo_id,
|
||||
dest=target_dir / "stable-diffusion-safety-checker",
|
||||
**kwargs,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
def _model_probe_to_path(self, info: ModelProbeInfo) -> Path:
|
||||
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
||||
|
||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs):
|
||||
if dest.exists() and not force:
|
||||
logger.info(f"Skipping existing {dest}")
|
||||
return
|
||||
model = model_class.from_pretrained(repo_id, **kwargs)
|
||||
self._save_pretrained(model, dest, overwrite=force)
|
||||
|
||||
def _save_pretrained(self, model, dest: Path, overwrite: bool = False):
|
||||
model_name = dest.name
|
||||
if overwrite:
|
||||
model.save_pretrained(dest, safe_serialization=True)
|
||||
else:
|
||||
download_path = dest.with_name(f"{model_name}.downloading")
|
||||
model.save_pretrained(download_path, safe_serialization=True)
|
||||
download_path.replace(dest)
|
||||
|
||||
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
|
||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
|
||||
info = ModelProbe().heuristic_probe(vae)
|
||||
_, model_name = repo_id.split("/")
|
||||
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
||||
vae.save_pretrained(dest, safe_serialization=True)
|
||||
return dest
|
||||
|
||||
def _vae_path(self, vae: Union[str, dict]) -> Path:
|
||||
"""
|
||||
Convert 2.3 VAE stanza to a straight path.
|
||||
"""
|
||||
vae_path = None
|
||||
|
||||
# First get a path
|
||||
if isinstance(vae, str):
|
||||
vae_path = vae
|
||||
|
||||
elif isinstance(vae, DictConfig):
|
||||
if p := vae.get("path"):
|
||||
vae_path = p
|
||||
elif repo_id := vae.get("repo_id"):
|
||||
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
|
||||
vae_path = "models/core/convert/sd-vae-ft-mse"
|
||||
return vae_path
|
||||
else:
|
||||
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
|
||||
|
||||
assert vae_path is not None, "Couldn't find VAE for this model"
|
||||
|
||||
# if the VAE is in the old models directory, then we must move it into the new
|
||||
# one. VAEs outside of this directory can stay where they are.
|
||||
vae_path = Path(vae_path)
|
||||
if vae_path.is_relative_to(self.src_paths.models):
|
||||
info = ModelProbe().heuristic_probe(vae_path)
|
||||
dest = self._model_probe_to_path(info) / vae_path.name
|
||||
if not dest.exists():
|
||||
if vae_path.is_dir():
|
||||
self.copy_dir(vae_path, dest)
|
||||
else:
|
||||
self.copy_file(vae_path, dest)
|
||||
vae_path = dest
|
||||
|
||||
if vae_path.is_relative_to(self.dest_models):
|
||||
rel_path = vae_path.relative_to(self.dest_models)
|
||||
return Path("models", rel_path)
|
||||
else:
|
||||
return vae_path
|
||||
|
||||
def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config):
|
||||
"""
|
||||
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
||||
"""
|
||||
dest_dir = self.dest_models
|
||||
|
||||
cache = self.root_directory / "models/hub"
|
||||
kwargs = {
|
||||
"cache_dir": cache,
|
||||
"safety_checker": None,
|
||||
# local_files_only = True,
|
||||
}
|
||||
|
||||
owner, repo_name = repo_id.split("/")
|
||||
model_name = model_name or repo_name
|
||||
model = cache / "--".join(["models", owner, repo_name])
|
||||
|
||||
if len(list(model.glob("snapshots/**/model_index.json"))) == 0:
|
||||
return
|
||||
revisions = [x.name for x in model.glob("refs/*")]
|
||||
|
||||
# if an fp16 is available we use that
|
||||
revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0]
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs)
|
||||
|
||||
info = ModelProbe().heuristic_probe(pipeline)
|
||||
if not info:
|
||||
return
|
||||
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
||||
return
|
||||
|
||||
dest = self._model_probe_to_path(info) / model_name
|
||||
self._save_pretrained(pipeline, dest)
|
||||
|
||||
rel_path = Path("models", dest.relative_to(dest_dir))
|
||||
self._add_model(model_name, info, rel_path, **extra_config)
|
||||
|
||||
def migrate_path(self, location: Path, model_name: str = None, **extra_config):
|
||||
"""
|
||||
Migrate a model referred to using 'weights' or 'path'
|
||||
"""
|
||||
|
||||
# handle relative paths
|
||||
dest_dir = self.dest_models
|
||||
location = self.root_directory / location
|
||||
model_name = model_name or location.stem
|
||||
|
||||
info = ModelProbe().heuristic_probe(location)
|
||||
if not info:
|
||||
return
|
||||
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
||||
return
|
||||
|
||||
# uh oh, weights is in the old models directory - move it into the new one
|
||||
if Path(location).is_relative_to(self.src_paths.models):
|
||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
|
||||
if location.is_dir():
|
||||
self.copy_dir(location, dest)
|
||||
else:
|
||||
self.copy_file(location, dest)
|
||||
location = Path("models", info.base_type.value, info.model_type.value, location.name)
|
||||
|
||||
self._add_model(model_name, info, location, **extra_config)
|
||||
|
||||
def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config):
|
||||
if info.model_type != ModelType.Main:
|
||||
return
|
||||
|
||||
self.mgr.add_model(
|
||||
model_name=model_name,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
clobber=True,
|
||||
model_attributes={
|
||||
"path": str(location),
|
||||
"description": f"A {info.base_type.value} {info.model_type.value} model",
|
||||
"model_format": info.format,
|
||||
"variant": info.variant_type.value,
|
||||
**extra_config,
|
||||
},
|
||||
)
|
||||
|
||||
def migrate_defined_models(self):
|
||||
"""
|
||||
Migrate models defined in models.yaml
|
||||
"""
|
||||
# find any models referred to in old models.yaml
|
||||
conf = OmegaConf.load(self.root_directory / "configs/models.yaml")
|
||||
|
||||
for model_name, stanza in conf.items():
|
||||
try:
|
||||
passthru_args = {}
|
||||
|
||||
if vae := stanza.get("vae"):
|
||||
try:
|
||||
passthru_args["vae"] = str(self._vae_path(vae))
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
|
||||
logger.warning(str(e))
|
||||
|
||||
if config := stanza.get("config"):
|
||||
passthru_args["config"] = config
|
||||
|
||||
if description := stanza.get("description"):
|
||||
passthru_args["description"] = description
|
||||
|
||||
if repo_id := stanza.get("repo_id"):
|
||||
logger.info(f"Migrating diffusers model {model_name}")
|
||||
self.migrate_repo_id(repo_id, model_name, **passthru_args)
|
||||
|
||||
elif location := stanza.get("weights"):
|
||||
logger.info(f"Migrating checkpoint model {model_name}")
|
||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||
|
||||
elif location := stanza.get("path"):
|
||||
logger.info(f"Migrating diffusers model {model_name}")
|
||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
def migrate(self):
|
||||
self.create_directory_structure()
|
||||
# the configure script is doing this
|
||||
self.migrate_support_models()
|
||||
self.migrate_conversion_models()
|
||||
self.migrate_tuning_models()
|
||||
self.migrate_defined_models()
|
||||
|
||||
|
||||
def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths:
|
||||
"""
|
||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||
"""
|
||||
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
|
||||
parser.add_argument(
|
||||
"--embedding_directory",
|
||||
"--embedding_path",
|
||||
type=Path,
|
||||
dest="embedding_path",
|
||||
default=Path("embeddings"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_directory",
|
||||
dest="lora_path",
|
||||
type=Path,
|
||||
default=Path("loras"),
|
||||
)
|
||||
opt, _ = parser.parse_known_args([f"@{str(initfile)}"])
|
||||
return ModelPaths(
|
||||
models=root / "models",
|
||||
embeddings=root / str(opt.embedding_path).strip('"'),
|
||||
loras=root / str(opt.lora_path).strip('"'),
|
||||
controlnets=root / "controlnets",
|
||||
)
|
||||
|
||||
|
||||
def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths:
|
||||
"""
|
||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||
"""
|
||||
# Don't use the config object because it is unforgiving of version updates
|
||||
# Just use omegaconf directly
|
||||
opt = OmegaConf.load(initfile)
|
||||
paths = opt.InvokeAI.Paths
|
||||
models = paths.get("models_dir", "models")
|
||||
embeddings = paths.get("embedding_dir", "embeddings")
|
||||
loras = paths.get("lora_dir", "loras")
|
||||
controlnets = paths.get("controlnet_dir", "controlnets")
|
||||
return ModelPaths(
|
||||
models=root / models if models else None,
|
||||
embeddings=root / embeddings if embeddings else None,
|
||||
loras=root / loras if loras else None,
|
||||
controlnets=root / controlnets if controlnets else None,
|
||||
)
|
||||
|
||||
|
||||
def get_legacy_embeddings(root: Path) -> ModelPaths:
|
||||
path = root / "invokeai.init"
|
||||
if path.exists():
|
||||
return _parse_legacy_initfile(root, path)
|
||||
path = root / "invokeai.yaml"
|
||||
if path.exists():
|
||||
return _parse_legacy_yamlfile(root, path)
|
||||
|
||||
|
||||
def do_migrate(src_directory: Path, dest_directory: Path):
|
||||
"""
|
||||
Migrate models from src to dest InvokeAI root directories
|
||||
"""
|
||||
config_file = dest_directory / "configs" / "models.yaml.3"
|
||||
dest_models = dest_directory / "models.3"
|
||||
|
||||
version_3 = (dest_directory / "models" / "core").exists()
|
||||
|
||||
# Here we create the destination models.yaml file.
|
||||
# If we are writing into a version 3 directory and the
|
||||
# file already exists, then we write into a copy of it to
|
||||
# avoid deleting its previous customizations. Otherwise we
|
||||
# create a new empty one.
|
||||
if version_3: # write into the dest directory
|
||||
try:
|
||||
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
|
||||
except Exception:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
||||
(dest_directory / "models").replace(dest_models)
|
||||
else:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file)
|
||||
|
||||
paths = get_legacy_embeddings(src_directory)
|
||||
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
|
||||
migrator.migrate()
|
||||
print("Migration successful.")
|
||||
|
||||
if not version_3:
|
||||
(dest_directory / "models").replace(src_directory / "models.orig")
|
||||
print(f"Original models directory moved to {dest_directory}/models.orig")
|
||||
|
||||
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
|
||||
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
|
||||
|
||||
config_file.replace(config_file.with_suffix(""))
|
||||
dest_models.replace(dest_models.with_suffix(""))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="invokeai-migrate3",
|
||||
description="""
|
||||
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
|
||||
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a
|
||||
|
||||
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
|
||||
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
|
||||
script, which will perform a full upgrade in place.""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-directory",
|
||||
dest="src_root",
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to-directory",
|
||||
dest="dest_root",
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
src_root = args.src_root
|
||||
assert src_root.is_dir(), f"{src_root} is not a valid directory"
|
||||
assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory"
|
||||
assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory"
|
||||
assert (src_root / "invokeai.init").exists() or (
|
||||
src_root / "invokeai.yaml"
|
||||
).exists(), f"{src_root} does not contain an InvokeAI init file."
|
||||
|
||||
dest_root = args.dest_root
|
||||
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args(["--root", str(dest_root)])
|
||||
|
||||
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
|
||||
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
|
||||
if not dest_is_setup:
|
||||
from invokeai.backend.install.invokeai_configure import initialize_rootdir
|
||||
|
||||
initialize_rootdir(dest_root, True)
|
||||
|
||||
do_migrate(src_root, dest_root)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,637 +0,0 @@
|
||||
"""
|
||||
Utility (backend) functions used by model_install.py
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from huggingface_hub import HfApi, HfFolder, hf_hub_url
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
|
||||
from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType
|
||||
from invokeai.backend.util import download_with_resume
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
from ..util.logging import InvokeAILogger
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.get_logger(name="InvokeAI")
|
||||
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
|
||||
Config_preamble = """
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
LEGACY_CONFIGS = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: "v1-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstallSelections:
|
||||
install_models: List[str] = field(default_factory=list)
|
||||
remove_models: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelLoadInfo:
|
||||
name: str
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
path: Optional[Path] = None
|
||||
repo_id: Optional[str] = None
|
||||
subfolder: Optional[str] = None
|
||||
description: str = ""
|
||||
installed: bool = False
|
||||
recommended: bool = False
|
||||
default: bool = False
|
||||
requires: Optional[List[str]] = field(default_factory=list)
|
||||
|
||||
|
||||
class ModelInstall(object):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
model_manager: Optional[ModelManager] = None,
|
||||
access_token: Optional[str] = None,
|
||||
civitai_api_key: Optional[str] = None,
|
||||
):
|
||||
self.config = config
|
||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||
self.datasets = OmegaConf.load(Dataset_path)
|
||||
self.prediction_helper = prediction_type_helper
|
||||
self.access_token = access_token or HfFolder.get_token()
|
||||
self.civitai_api_key = civitai_api_key or config.civitai_api_key
|
||||
self.reverse_paths = self._reverse_paths(self.datasets)
|
||||
|
||||
def all_models(self) -> Dict[str, ModelLoadInfo]:
|
||||
"""
|
||||
Return dict of model_key=>ModelLoadInfo objects.
|
||||
This method consolidates and simplifies the entries in both
|
||||
models.yaml and INITIAL_MODELS.yaml so that they can
|
||||
be treated uniformly. It also sorts the models alphabetically
|
||||
by their name, to improve the display somewhat.
|
||||
"""
|
||||
model_dict = {}
|
||||
|
||||
# first populate with the entries in INITIAL_MODELS.yaml
|
||||
for key, value in self.datasets.items():
|
||||
name, base, model_type = ModelManager.parse_key(key)
|
||||
value["name"] = name
|
||||
value["base_type"] = base
|
||||
value["model_type"] = model_type
|
||||
model_info = ModelLoadInfo(**value)
|
||||
if model_info.subfolder and model_info.repo_id:
|
||||
model_info.repo_id += f":{model_info.subfolder}"
|
||||
model_dict[key] = model_info
|
||||
|
||||
# supplement with entries in models.yaml
|
||||
installed_models = list(self.mgr.list_models())
|
||||
|
||||
for md in installed_models:
|
||||
base = md["base_model"]
|
||||
model_type = md["model_type"]
|
||||
name = md["model_name"]
|
||||
key = ModelManager.create_key(name, base, model_type)
|
||||
if key in model_dict:
|
||||
model_dict[key].installed = True
|
||||
else:
|
||||
model_dict[key] = ModelLoadInfo(
|
||||
name=name,
|
||||
base_type=base,
|
||||
model_type=model_type,
|
||||
path=value.get("path"),
|
||||
installed=True,
|
||||
)
|
||||
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
|
||||
|
||||
def _is_autoloaded(self, model_info: dict) -> bool:
|
||||
path = model_info.get("path")
|
||||
if not path:
|
||||
return False
|
||||
for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
|
||||
if autodir_path := getattr(self.config, autodir):
|
||||
autodir_path = self.config.root_path / autodir_path
|
||||
if Path(path).is_relative_to(autodir_path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_models(self, model_type):
|
||||
installed = self.mgr.list_models(model_type=model_type)
|
||||
print()
|
||||
print(f"Installed models of type `{model_type}`:")
|
||||
print(f"{'Model Key':50} Model Path")
|
||||
for i in installed:
|
||||
print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}")
|
||||
print()
|
||||
|
||||
# logic here a little reversed to maintain backward compatibility
|
||||
def starter_models(self, all_models: bool = False) -> Set[str]:
|
||||
models = set()
|
||||
for key, _value in self.datasets.items():
|
||||
name, base, model_type = ModelManager.parse_key(key)
|
||||
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
|
||||
models.add(key)
|
||||
return models
|
||||
|
||||
def recommended_models(self) -> Set[str]:
|
||||
starters = self.starter_models(all_models=True)
|
||||
return {x for x in starters if self.datasets[x].get("recommended", False)}
|
||||
|
||||
def default_model(self) -> str:
|
||||
starters = self.starter_models()
|
||||
defaults = [x for x in starters if self.datasets[x].get("default", False)]
|
||||
return defaults[0]
|
||||
|
||||
def install(self, selections: InstallSelections):
|
||||
verbosity = dlogging.get_verbosity() # quench NSFW nags
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
job = 1
|
||||
jobs = len(selections.remove_models) + len(selections.install_models)
|
||||
|
||||
# remove requested models
|
||||
for key in selections.remove_models:
|
||||
name, base, mtype = self.mgr.parse_key(key)
|
||||
logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
|
||||
try:
|
||||
self.mgr.del_model(name, base, mtype)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(e)
|
||||
job += 1
|
||||
|
||||
# add requested models
|
||||
self._remove_installed(selections.install_models)
|
||||
self._add_required_models(selections.install_models)
|
||||
for path in selections.install_models:
|
||||
logger.info(f"Installing {path} [{job}/{jobs}]")
|
||||
try:
|
||||
self.heuristic_import(path)
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.error(str(e))
|
||||
job += 1
|
||||
|
||||
dlogging.set_verbosity(verbosity)
|
||||
self.mgr.commit()
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
model_path_id_or_url: Union[str, Path],
|
||||
models_installed: Set[Path] = None,
|
||||
) -> Dict[str, AddModelResult]:
|
||||
"""
|
||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||
:param models_installed: Set of installed models, used for recursive invocation
|
||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||
"""
|
||||
|
||||
if not models_installed:
|
||||
models_installed = {}
|
||||
|
||||
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
|
||||
|
||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||
self.current_id = model_path_id_or_url
|
||||
path = Path(model_path_id_or_url)
|
||||
|
||||
# fix relative paths
|
||||
if path.exists() and not path.is_absolute():
|
||||
path = path.absolute() # make relative to current WD
|
||||
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update({str(path): self._install_path(path)})
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any(
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"pytorch_lora_weights.safetensors",
|
||||
}
|
||||
):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
||||
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
|
||||
# huggingface repo
|
||||
elif len(str(model_path_id_or_url).split("/")) == 2:
|
||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||
|
||||
# a URL
|
||||
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||
|
||||
else:
|
||||
raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")
|
||||
|
||||
return models_installed
|
||||
|
||||
def _remove_installed(self, model_list: List[str]):
|
||||
all_models = self.all_models()
|
||||
models_to_remove = []
|
||||
|
||||
for path in model_list:
|
||||
key = self.reverse_paths.get(path)
|
||||
if key and all_models[key].installed:
|
||||
models_to_remove.append(path)
|
||||
|
||||
for path in models_to_remove:
|
||||
logger.warning(f"{path} already installed. Skipping")
|
||||
model_list.remove(path)
|
||||
|
||||
def _add_required_models(self, model_list: List[str]):
|
||||
additional_models = []
|
||||
all_models = self.all_models()
|
||||
for path in model_list:
|
||||
if not (key := self.reverse_paths.get(path)):
|
||||
continue
|
||||
for requirement in all_models[key].requires:
|
||||
requirement_key = self.reverse_paths.get(requirement)
|
||||
if not all_models[requirement_key].installed:
|
||||
additional_models.append(requirement)
|
||||
model_list.extend(additional_models)
|
||||
|
||||
# install a model from a local path. The optional info parameter is there to prevent
|
||||
# the model from being probed twice in the event that it has already been probed.
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
|
||||
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f"Unable to parse format of {path}")
|
||||
return None
|
||||
model_name = path.stem if path.is_file() else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path, info)
|
||||
return self.mgr.add_model(
|
||||
model_name=model_name,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
model_attributes=attributes,
|
||||
)
|
||||
|
||||
def _install_url(self, url: str) -> AddModelResult:
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
CIVITAI_RE = r".*civitai.com.*"
|
||||
civit_url = re.match(CIVITAI_RE, url, re.IGNORECASE)
|
||||
location = download_with_resume(
|
||||
url, Path(staging), access_token=self.civitai_api_key if civit_url else None
|
||||
)
|
||||
if not location:
|
||||
logger.error(f"Unable to download {url}. Skipping.")
|
||||
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
models_path = shutil.move(location, dest)
|
||||
|
||||
# staged version will be garbage-collected at this time
|
||||
return self._install_path(Path(models_path), info)
|
||||
|
||||
def _install_repo(self, repo_id: str) -> AddModelResult:
|
||||
# hack to recover models stored in subfolders --
|
||||
# Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster
|
||||
subfolder = None
|
||||
if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id):
|
||||
repo_id = match.group(1)
|
||||
subfolder = match.group(2)
|
||||
|
||||
hinfo = HfApi().model_info(repo_id)
|
||||
|
||||
# we try to figure out how to download this most economically
|
||||
# list all the files in the repo
|
||||
files = [x.rfilename for x in hinfo.siblings]
|
||||
if subfolder:
|
||||
files = [x for x in files if x.startswith(f"{subfolder}/")]
|
||||
prefix = f"{subfolder}/" if subfolder else ""
|
||||
|
||||
location = None
|
||||
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
staging = Path(staging)
|
||||
if f"{prefix}model_index.json" in files:
|
||||
location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline
|
||||
elif f"{prefix}unet/model.onnx" in files:
|
||||
location = self._download_hf_model(repo_id, files, staging)
|
||||
else:
|
||||
for suffix in ["safetensors", "bin"]:
|
||||
if f"{prefix}pytorch_lora_weights.{suffix}" in files:
|
||||
location = self._download_hf_model(
|
||||
repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder
|
||||
) # LoRA
|
||||
break
|
||||
elif (
|
||||
self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files
|
||||
): # vae, controlnet or some other standalone
|
||||
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
elif f"{prefix}diffusion_pytorch_model.{suffix}" in files:
|
||||
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
elif f"{prefix}learned_embeds.{suffix}" in files:
|
||||
location = self._download_hf_model(
|
||||
repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder
|
||||
)
|
||||
break
|
||||
elif (
|
||||
f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files
|
||||
): # IP-Adapter
|
||||
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files:
|
||||
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
|
||||
# by InvokeAI for use with IP-Adapters.
|
||||
files = ["config.json", f"model.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
if not location:
|
||||
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
|
||||
return {}
|
||||
|
||||
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f"Could not probe {location}. Skipping install.")
|
||||
return {}
|
||||
dest = (
|
||||
self.config.models_path
|
||||
/ info.base_type.value
|
||||
/ info.model_type.value
|
||||
/ self._get_model_name(repo_id, location)
|
||||
)
|
||||
if dest.exists():
|
||||
shutil.rmtree(dest)
|
||||
shutil.copytree(location, dest)
|
||||
return self._install_path(dest, info)
|
||||
|
||||
def _get_model_name(self, path_name: str, location: Path) -> str:
|
||||
"""
|
||||
Calculate a name for the model - primitive implementation.
|
||||
"""
|
||||
if key := self.reverse_paths.get(path_name):
|
||||
(name, base, mtype) = ModelManager.parse_key(key)
|
||||
return name
|
||||
elif location.is_dir():
|
||||
return location.name
|
||||
else:
|
||||
return location.stem
|
||||
|
||||
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
|
||||
model_name = path.name if path.is_dir() else path.stem
|
||||
description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
|
||||
if key := self.reverse_paths.get(self.current_id):
|
||||
if key in self.datasets:
|
||||
description = self.datasets[key].get("description") or description
|
||||
|
||||
rel_path = self.relative_to_root(path, self.config.models_path)
|
||||
|
||||
attributes = {
|
||||
"path": str(rel_path),
|
||||
"description": str(description),
|
||||
"model_format": info.format,
|
||||
}
|
||||
legacy_conf = None
|
||||
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
||||
attributes.update(
|
||||
{
|
||||
"variant": info.variant_type,
|
||||
}
|
||||
)
|
||||
if info.format == "checkpoint":
|
||||
try:
|
||||
possible_conf = path.with_suffix(".yaml")
|
||||
if possible_conf.exists():
|
||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||
elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
legacy_conf = Path(
|
||||
self.config.legacy_conf_dir,
|
||||
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
|
||||
)
|
||||
else:
|
||||
legacy_conf = Path(
|
||||
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
|
||||
)
|
||||
except KeyError:
|
||||
legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
|
||||
|
||||
if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
|
||||
possible_conf = path.with_suffix(".yaml")
|
||||
if possible_conf.exists():
|
||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||
else:
|
||||
legacy_conf = Path(
|
||||
self.config.root_path,
|
||||
"configs/controlnet",
|
||||
("cldm_v15.yaml" if info.base_type == BaseModelType("sd-1") else "cldm_v21.yaml"),
|
||||
)
|
||||
|
||||
if legacy_conf:
|
||||
attributes.update({"config": str(legacy_conf)})
|
||||
return attributes
|
||||
|
||||
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
|
||||
root = root or self.config.root_path
|
||||
if path.is_relative_to(root):
|
||||
return path.relative_to(root)
|
||||
else:
|
||||
return path
|
||||
|
||||
def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path:
|
||||
"""
|
||||
Retrieve a StableDiffusion model from cache or remote and then
|
||||
does a save_pretrained() to the indicated staging area.
|
||||
"""
|
||||
_, name = repo_id.split("/")
|
||||
precision = torch_dtype(choose_torch_device())
|
||||
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
|
||||
|
||||
model = None
|
||||
for variant in variants:
|
||||
try:
|
||||
model = DiffusionPipeline.from_pretrained(
|
||||
repo_id,
|
||||
variant=variant,
|
||||
torch_dtype=precision,
|
||||
safety_checker=None,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
|
||||
if "fp16" not in str(e):
|
||||
print(e)
|
||||
|
||||
if model:
|
||||
break
|
||||
|
||||
if not model:
|
||||
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
|
||||
return None
|
||||
model.save_pretrained(staging / name, safe_serialization=True)
|
||||
return staging / name
|
||||
|
||||
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
|
||||
_, name = repo_id.split("/")
|
||||
location = staging / name
|
||||
paths = []
|
||||
for filename in files:
|
||||
filePath = Path(filename)
|
||||
p = hf_download_with_resume(
|
||||
repo_id,
|
||||
model_dir=location / filePath.parent,
|
||||
model_name=filePath.name,
|
||||
access_token=self.access_token,
|
||||
subfolder=filePath.parent / subfolder if subfolder else filePath.parent,
|
||||
)
|
||||
if p:
|
||||
paths.append(p)
|
||||
else:
|
||||
logger.warning(f"Could not download {filename} from {repo_id}.")
|
||||
|
||||
return location if len(paths) > 0 else None
|
||||
|
||||
@classmethod
|
||||
def _reverse_paths(cls, datasets) -> dict:
|
||||
"""
|
||||
Reverse mapping from repo_id/path to destination name.
|
||||
"""
|
||||
return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def yes_or_no(prompt: str, default_yes=True):
|
||||
default = "y" if default_yes else "n"
|
||||
response = input(f"{prompt} [{default}] ") or default
|
||||
if default_yes:
|
||||
return response[0] not in ("n", "N")
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
|
||||
logger = InvokeAILogger.get_logger("InvokeAI")
|
||||
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
return destination
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_with_resume(
|
||||
repo_id: str,
|
||||
model_dir: str,
|
||||
model_name: str,
|
||||
model_dest: Path = None,
|
||||
access_token: str = None,
|
||||
subfolder: str = None,
|
||||
) -> Path:
|
||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
|
||||
|
||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||
open_mode = "wb"
|
||||
exist_size = 0
|
||||
|
||||
if os.path.exists(model_dest):
|
||||
exist_size = os.path.getsize(model_dest)
|
||||
header["Range"] = f"bytes={exist_size}-"
|
||||
open_mode = "ab"
|
||||
|
||||
resp = requests.get(url, headers=header, stream=True)
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
|
||||
if resp.status_code == 416: # "range not satisfiable", which means nothing to return
|
||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code == 404:
|
||||
logger.warning("File not found")
|
||||
return None
|
||||
elif resp.status_code != 200:
|
||||
logger.warning(f"{model_name}: {resp.reason}")
|
||||
elif exist_size > 0:
|
||||
logger.info(f"{model_name}: partial file found. Resuming...")
|
||||
else:
|
||||
logger.info(f"{model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
with (
|
||||
open(model_dest, open_mode) as file,
|
||||
tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
total=total + exist_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
) as bar,
|
||||
):
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
@@ -8,8 +8,8 @@ from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
||||
from invokeai.backend.model_management.models.base import calc_model_size_by_data
|
||||
|
||||
from ..raw_model import RawModel
|
||||
from .resampler import Resampler
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ class MLPProjModel(torch.nn.Module):
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class IPAdapter:
|
||||
class IPAdapter(RawModel):
|
||||
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
||||
|
||||
def __init__(
|
||||
@@ -124,6 +124,9 @@ class IPAdapter:
|
||||
self.attn_weights.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def calc_size(self):
|
||||
# workaround for circular import
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
|
||||
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
|
||||
|
||||
def _init_image_proj_model(self, state_dict):
|
||||
|
||||
@@ -1,98 +1,17 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
"""LoRA model support."""
|
||||
|
||||
import bisect
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from typing_extensions import Self
|
||||
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
)
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
|
||||
|
||||
class LoRAModelFormat(str, Enum):
|
||||
LyCORIS = "lycoris"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class LoRAModel(ModelBase):
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: LoRAModelFormat # TODO:
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Lora
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.model_size = os.path.getsize(self.model_path)
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in lora")
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in lora")
|
||||
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=self.model_path,
|
||||
dtype=torch_dtype,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
|
||||
self.model_size = model.calc_size()
|
||||
return model
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
for ext in ["safetensors", "bin"]:
|
||||
if os.path.exists(os.path.join(path, f"pytorch_lora_weights.{ext}")):
|
||||
return LoRAModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return LoRAModelFormat.LyCORIS
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
|
||||
for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder
|
||||
path = Path(model_path, f"pytorch_lora_weights.{ext}")
|
||||
if path.exists():
|
||||
return path
|
||||
else:
|
||||
return model_path
|
||||
from .raw_model import RawModel
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
@@ -108,7 +27,7 @@ class LoRALayerBase:
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
@@ -116,7 +35,7 @@ class LoRALayerBase:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias = torch.sparse_coo_tensor(
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
@@ -128,7 +47,7 @@ class LoRALayerBase:
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def calc_size(self) -> int:
|
||||
@@ -142,7 +61,7 @@ class LoRALayerBase:
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
@@ -156,20 +75,20 @@ class LoRALayer(LoRALayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
if "lora_mid.weight" in values:
|
||||
self.mid = values["lora_mid.weight"]
|
||||
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
|
||||
else:
|
||||
self.mid = None
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
@@ -190,7 +109,7 @@ class LoRALayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
@@ -208,11 +127,7 @@ class LoHALayer(LoRALayerBase):
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
@@ -221,20 +136,20 @@ class LoHALayer(LoRALayerBase):
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
|
||||
if "hada_t1" in values:
|
||||
self.t1 = values["hada_t1"]
|
||||
self.t1: Optional[torch.Tensor] = values["hada_t1"]
|
||||
else:
|
||||
self.t1 = None
|
||||
|
||||
if "hada_t2" in values:
|
||||
self.t2 = values["hada_t2"]
|
||||
self.t2: Optional[torch.Tensor] = values["hada_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
@@ -254,7 +169,7 @@ class LoHALayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
@@ -280,12 +195,12 @@ class LoKRLayer(LoRALayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
if "lokr_w1" in values:
|
||||
self.w1 = values["lokr_w1"]
|
||||
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
|
||||
self.w1_a = None
|
||||
self.w1_b = None
|
||||
else:
|
||||
@@ -294,7 +209,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
|
||||
if "lokr_w2" in values:
|
||||
self.w2 = values["lokr_w2"]
|
||||
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
else:
|
||||
@@ -303,7 +218,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
|
||||
if "lokr_t2" in values:
|
||||
self.t2 = values["lokr_t2"]
|
||||
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
@@ -314,14 +229,18 @@ class LoKRLayer(LoRALayerBase):
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
w1 = self.w1
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
@@ -329,6 +248,8 @@ class LoKRLayer(LoRALayerBase):
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
@@ -344,18 +265,22 @@ class LoKRLayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
@@ -369,7 +294,7 @@ class FullLayer(LoRALayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
@@ -382,7 +307,7 @@ class FullLayer(LoRALayerBase):
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
@@ -394,7 +319,7 @@ class FullLayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
@@ -407,7 +332,7 @@ class IA3Layer(LoRALayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
@@ -416,10 +341,11 @@ class IA3Layer(LoRALayerBase):
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
@@ -439,28 +365,30 @@ class IA3Layer(LoRALayerBase):
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
||||
class LoRAModelRaw: # (torch.nn.Module):
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, LoRALayer]
|
||||
layers: Dict[str, AnyLoRALayer]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, LoRALayer],
|
||||
layers: Dict[str, AnyLoRALayer],
|
||||
):
|
||||
self._name = name
|
||||
self.layers = layers
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
@@ -472,7 +400,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict):
|
||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
@@ -536,7 +464,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
):
|
||||
) -> Self:
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
@@ -544,16 +472,16 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
file_path = Path(file_path)
|
||||
|
||||
model = cls(
|
||||
name=file_path.stem, # TODO:
|
||||
name=file_path.stem,
|
||||
layers={},
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
sd = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
sd = torch.load(file_path, map_location="cpu")
|
||||
|
||||
state_dict = cls._group_state(state_dict)
|
||||
state_dict = cls._group_state(sd)
|
||||
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
@@ -561,7 +489,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
for layer_key, values in state_dict.items():
|
||||
# lora and locon
|
||||
if "lora_down.weight" in values:
|
||||
layer = LoRALayer(layer_key, values)
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_b" in values:
|
||||
@@ -592,8 +520,8 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: dict):
|
||||
state_dict_groupped = {}
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
@@ -606,7 +534,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def make_sdxl_unet_conversion_map():
|
||||
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
# Model Cache
|
||||
|
||||
## `glibc` Memory Allocator Fragmentation
|
||||
|
||||
Python (and PyTorch) relies on the memory allocator from the C Standard Library (`libc`). On linux, with the GNU C Standard Library implementation (`glibc`), our memory access patterns have been observed to cause severe memory fragmentation. This fragmentation results in large amounts of memory that has been freed but can't be released back to the OS. Loading models from disk and moving them between CPU/CUDA seem to be the operations that contribute most to the fragmentation. This memory fragmentation issue can result in OOM crashes during frequent model switching, even if `max_cache_size` is set to a reasonable value (e.g. a OOM crash with `max_cache_size=16` on a system with 32GB of RAM).
|
||||
|
||||
This problem may also exist on other OSes, and other `libc` implementations. But, at the time of writing, it has only been investigated on linux with `glibc`.
|
||||
|
||||
To better understand how the `glibc` memory allocator works, see these references:
|
||||
- Basics: https://www.gnu.org/software/libc/manual/html_node/The-GNU-Allocator.html
|
||||
- Details: https://sourceware.org/glibc/wiki/MallocInternals
|
||||
|
||||
Note the differences between memory allocated as chunks in an arena vs. memory allocated with `mmap`. Under `glibc`'s default configuration, most model tensors get allocated as chunks in an arena making them vulnerable to the problem of fragmentation.
|
||||
|
||||
We can work around this memory fragmentation issue by setting the following env var:
|
||||
|
||||
```bash
|
||||
# Force blocks >1MB to be allocated with `mmap` so that they are released to the system immediately when they are freed.
|
||||
MALLOC_MMAP_THRESHOLD_=1048576
|
||||
```
|
||||
|
||||
See the following references for more information about the `malloc` tunable parameters:
|
||||
- https://www.gnu.org/software/libc/manual/html_node/Malloc-Tunable-Parameters.html
|
||||
- https://www.gnu.org/software/libc/manual/html_node/Memory-Allocation-Tunables.html
|
||||
- https://man7.org/linux/man-pages/man3/mallopt.3.html
|
||||
|
||||
The model cache emits debug logs that provide visibility into the state of the `libc` memory allocator. See the `LibcUtil` class for more info on how these `libc` malloc stats are collected.
|
||||
@@ -1,20 +0,0 @@
|
||||
# ruff: noqa: I001, F401
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
# This import must be first
|
||||
from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType
|
||||
from .lora import ModelPatcher, ONNXModelPatcher
|
||||
from .model_cache import ModelCache
|
||||
|
||||
from .models import (
|
||||
BaseModelType,
|
||||
DuplicateModelException,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
# This import must be last
|
||||
from .model_merge import MergeInterpolationMethod, ModelMerger
|
||||
@@ -1,31 +0,0 @@
|
||||
# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
This module exports the function has_baked_in_sdxl_vae().
|
||||
It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE,
|
||||
which doesn't work properly in fp16 mode.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51"
|
||||
|
||||
|
||||
def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool:
|
||||
"""Return true if the checkpoint contains a custom (non SDXL-1.0) VAE."""
|
||||
hash = _vae_hash(checkpoint_path)
|
||||
return hash != SDXL_1_0_VAE_HASH
|
||||
|
||||
|
||||
def _vae_hash(checkpoint_path: Path) -> str:
|
||||
checkpoint = load_file(checkpoint_path, device="cpu")
|
||||
vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")]
|
||||
hash = hashlib.new("sha256")
|
||||
for key in vae_keys:
|
||||
value = checkpoint[key]
|
||||
hash.update(bytes(key, "UTF-8"))
|
||||
hash.update(bytes(str(value), "UTF-8"))
|
||||
|
||||
return hash.hexdigest()
|
||||
@@ -1,553 +0,0 @@
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
|
||||
The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_cache_size=7.5)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import gc
|
||||
import hashlib
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Type, Union, types
|
||||
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
|
||||
|
||||
from ..util.devices import choose_torch_device
|
||||
from .models import BaseModelType, ModelBase, ModelType, SubModelType
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
# {submodel_key => size}
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelLocker(object):
|
||||
"Forward declaration"
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelCache(object):
|
||||
"Forward declaration"
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class _CacheRecord:
|
||||
size: int
|
||||
model: Any
|
||||
cache: ModelCache
|
||||
_locks: int
|
||||
|
||||
def __init__(self, cache, model: Any, size: int):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.cache = cache
|
||||
self._locks = 0
|
||||
|
||||
def lock(self):
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self):
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self):
|
||||
return self._locks > 0
|
||||
|
||||
@property
|
||||
def loaded(self):
|
||||
if self.model is not None and hasattr(self.model, "device"):
|
||||
return self.model.device != self.cache.storage_device
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class ModelCache(object):
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
logger: types.ModuleType = logger,
|
||||
log_memory_usage: bool = False,
|
||||
):
|
||||
"""
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
"""
|
||||
self.model_infos: Dict[str, ModelBase] = {}
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self.precision: torch.dtype = precision
|
||||
self.max_cache_size: float = max_cache_size
|
||||
self.max_vram_cache_size: float = max_vram_cache_size
|
||||
self.execution_device: torch.device = execution_device
|
||||
self.storage_device: torch.device = storage_device
|
||||
self.sha_chunksize = sha_chunksize
|
||||
self.logger = logger
|
||||
self._log_memory_usage = log_memory_usage
|
||||
|
||||
# used for stats collection
|
||||
self.stats = None
|
||||
|
||||
self._cached_models = {}
|
||||
self._cache_stack = []
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
return None
|
||||
|
||||
def get_key(
|
||||
self,
|
||||
model_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
):
|
||||
key = f"{model_path}:{base_model}:{model_type}"
|
||||
if submodel_type:
|
||||
key += f":{submodel_type}"
|
||||
return key
|
||||
|
||||
def _get_model_info(
|
||||
self,
|
||||
model_path: str,
|
||||
model_class: Type[ModelBase],
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
model_info_key = self.get_key(
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel_type=None,
|
||||
)
|
||||
|
||||
if model_info_key not in self.model_infos:
|
||||
self.model_infos[model_info_key] = model_class(
|
||||
model_path,
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
|
||||
return self.model_infos[model_info_key]
|
||||
|
||||
# TODO: args
|
||||
def get_model(
|
||||
self,
|
||||
model_path: Union[str, Path],
|
||||
model_class: Type[ModelBase],
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
gpu_load: bool = True,
|
||||
) -> Any:
|
||||
if not isinstance(model_path, Path):
|
||||
model_path = Path(model_path)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise Exception(f"Model not found: {model_path}")
|
||||
|
||||
model_info = self._get_model_info(
|
||||
model_path=model_path,
|
||||
model_class=model_class,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
key = self.get_key(
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel_type=submodel,
|
||||
)
|
||||
# TODO: lock for no copies on simultaneous calls?
|
||||
cache_entry = self._cached_models.get(key, None)
|
||||
if cache_entry is None:
|
||||
self.logger.info(
|
||||
f"Loading model {model_path}, type"
|
||||
f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
|
||||
)
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
|
||||
self_reported_model_size_before_load = model_info.get_size(submodel)
|
||||
# Remove old models from the cache to make room for the new model.
|
||||
self._make_cache_room(self_reported_model_size_before_load)
|
||||
|
||||
# Load the model from disk and capture a memory snapshot before/after.
|
||||
start_load_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
with skip_torch_weight_init():
|
||||
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_load_time = time.time()
|
||||
|
||||
self_reported_model_size_after_load = model_info.get_size(submodel)
|
||||
|
||||
self.logger.debug(
|
||||
f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s.\n"
|
||||
f"Self-reported size before/after load: {(self_reported_model_size_before_load/GIG):.3f}GB /"
|
||||
f" {(self_reported_model_size_after_load/GIG):.3f}GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if abs(self_reported_model_size_after_load - self_reported_model_size_before_load) > 10 * MB:
|
||||
self.logger.debug(
|
||||
f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:"
|
||||
f" {(self_reported_model_size_before_load/GIG):.2f}GB /"
|
||||
f" {(self_reported_model_size_after_load/GIG):.2f}GB."
|
||||
)
|
||||
|
||||
cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load)
|
||||
self._cached_models[key] = cache_entry
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
|
||||
if self.stats:
|
||||
self.stats.cache_size = self.max_cache_size * GIG
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[key] = max(
|
||||
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
|
||||
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
|
||||
|
||||
def _move_model_to_device(self, key: str, target_device: torch.device):
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
source_device = cache_entry.model.device
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
|
||||
# multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
cache_entry.model.to(target_device)
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Moved model '{key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if (
|
||||
snapshot_before is not None
|
||||
and snapshot_after is not None
|
||||
and snapshot_before.vram is not None
|
||||
and snapshot_after.vram is not None
|
||||
):
|
||||
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# If the estimated model size does not match the change in VRAM, log a warning.
|
||||
if not math.isclose(
|
||||
vram_change,
|
||||
cache_entry.size,
|
||||
rel_tol=0.1,
|
||||
abs_tol=10 * MB,
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Moving model '{key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
class ModelLocker(object):
|
||||
def __init__(self, cache, key, model, gpu_load, size_needed):
|
||||
"""
|
||||
:param cache: The model_cache object
|
||||
:param key: The key of the model to lock in GPU
|
||||
:param model: The model to lock
|
||||
:param gpu_load: True if load into gpu
|
||||
:param size_needed: Size of the model to load
|
||||
"""
|
||||
self.gpu_load = gpu_load
|
||||
self.cache = cache
|
||||
self.key = key
|
||||
self.model = model
|
||||
self.size_needed = size_needed
|
||||
self.cache_entry = self.cache._cached_models[self.key]
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
if not hasattr(self.model, "to"):
|
||||
return self.model
|
||||
|
||||
# NOTE that the model has to have the to() method in order for this
|
||||
# code to move it into GPU!
|
||||
if self.gpu_load:
|
||||
self.cache_entry.lock()
|
||||
|
||||
try:
|
||||
if self.cache.lazy_offloading:
|
||||
self.cache._offload_unlocked_models(self.size_needed)
|
||||
|
||||
self.cache._move_model_to_device(self.key, self.cache.execution_device)
|
||||
|
||||
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
|
||||
self.cache._print_cuda_stats()
|
||||
|
||||
except Exception:
|
||||
self.cache_entry.unlock()
|
||||
raise
|
||||
|
||||
# TODO: not fully understand
|
||||
# in the event that the caller wants the model in RAM, we
|
||||
# move it into CPU if it is in GPU and not locked
|
||||
elif self.cache_entry.loaded and not self.cache_entry.locked:
|
||||
self.cache._move_model_to_device(self.key, self.cache.storage_device)
|
||||
|
||||
return self.model
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
if not hasattr(self.model, "to"):
|
||||
return
|
||||
|
||||
self.cache_entry.unlock()
|
||||
if not self.cache.lazy_offloading:
|
||||
self.cache._offload_unlocked_models()
|
||||
self.cache._print_cuda_stats()
|
||||
|
||||
# TODO: should it be called untrack_model?
|
||||
def uncache_model(self, cache_id: str):
|
||||
with suppress(ValueError):
|
||||
self._cache_stack.remove(cache_id)
|
||||
self._cached_models.pop(cache_id, None)
|
||||
|
||||
def model_hash(
|
||||
self,
|
||||
model_path: Union[str, Path],
|
||||
) -> str:
|
||||
"""
|
||||
Given the HF repo id or path to a model on disk, returns a unique
|
||||
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
||||
|
||||
:param model_path: Path to model file/directory on disk.
|
||||
"""
|
||||
return self._local_model_hash(model_path)
|
||||
|
||||
def cache_size(self) -> float:
|
||||
"""Return the current size of the cache, in GB."""
|
||||
return self._cache_size() / GIG
|
||||
|
||||
def _has_cuda(self) -> bool:
|
||||
return self.execution_device.type == "cuda"
|
||||
|
||||
def _print_cuda_stats(self):
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % self.cache_size()
|
||||
|
||||
cached_models = 0
|
||||
loaded_models = 0
|
||||
locked_models = 0
|
||||
for model_info in self._cached_models.values():
|
||||
cached_models += 1
|
||||
if model_info.loaded:
|
||||
loaded_models += 1
|
||||
if model_info.locked:
|
||||
locked_models += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
|
||||
f" {cached_models}/{loaded_models}/{locked_models}"
|
||||
)
|
||||
|
||||
def _cache_size(self) -> int:
|
||||
return sum([m.size for m in self._cached_models.values()])
|
||||
|
||||
def _make_cache_room(self, model_size):
|
||||
# calculate how much memory this model will require
|
||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||
bytes_needed = model_size
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
current_size = self._cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GIG):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
|
||||
|
||||
pos = 0
|
||||
models_cleared = 0
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
|
||||
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
|
||||
# https://docs.python.org/3/library/gc.html#gc.get_referrers
|
||||
|
||||
# manualy clear local variable references of just finished function calls
|
||||
# for some reason python don't want to collect it even by gc.collect() immidiately
|
||||
if refs > 2:
|
||||
while True:
|
||||
cleared = False
|
||||
for referrer in gc.get_referrers(cache_entry.model):
|
||||
if type(referrer).__name__ == "frame":
|
||||
# RuntimeError: cannot clear an executing frame
|
||||
with suppress(RuntimeError):
|
||||
referrer.clear()
|
||||
cleared = True
|
||||
# break
|
||||
|
||||
# repeat if referrers changes(due to frame clear), else exit loop
|
||||
if cleared:
|
||||
gc.collect()
|
||||
else:
|
||||
break
|
||||
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
|
||||
f" refs: {refs}"
|
||||
)
|
||||
|
||||
# Expected refs:
|
||||
# 1 from cache_entry
|
||||
# 1 from getrefcount function
|
||||
# 1 from onnx runtime object
|
||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
||||
self.logger.debug(
|
||||
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
if self.stats:
|
||||
self.stats.cleared += 1
|
||||
del self._cache_stack[pos]
|
||||
del self._cached_models[model_key]
|
||||
del cache_entry
|
||||
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
if models_cleared > 0:
|
||||
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
|
||||
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
|
||||
# is high even if no garbage gets collected.)
|
||||
#
|
||||
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
|
||||
# - If models had to be cleared, it's a signal that we are close to our memory limit.
|
||||
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
|
||||
# collected.
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
gc.collect()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _offload_unlocked_models(self, size_needed: int = 0):
|
||||
reserved = self.max_vram_cache_size * GIG
|
||||
vram_in_use = torch.cuda.memory_allocated()
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.locked and cache_entry.loaded:
|
||||
self._move_model_to_device(model_key, self.storage_device)
|
||||
|
||||
vram_in_use = torch.cuda.memory_allocated()
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
||||
sha = hashlib.sha256()
|
||||
path = Path(model_path)
|
||||
|
||||
hashpath = path / "checksum.sha256"
|
||||
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
self.logger.debug(f"computing hash of model {path.name}")
|
||||
for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")):
|
||||
with open(file, "rb") as f:
|
||||
while chunk := f.read(self.sha_chunksize):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,140 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.model_management.model_merge exports:
|
||||
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
||||
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ...backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
|
||||
|
||||
|
||||
class MergeInterpolationMethod(str, Enum):
|
||||
WeightedSum = "weighted_sum"
|
||||
Sigmoid = "sigmoid"
|
||||
InvSigmoid = "inv_sigmoid"
|
||||
AddDifference = "add_difference"
|
||||
|
||||
|
||||
class ModelMerger(object):
|
||||
def __init__(self, manager: ModelManager):
|
||||
self.manager = manager
|
||||
|
||||
def merge_diffusion_models(
|
||||
self,
|
||||
model_paths: List[Path],
|
||||
alpha: float = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_paths[0],
|
||||
custom_pipeline="checkpoint_merger",
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_paths,
|
||||
alpha=alpha,
|
||||
interp=interp.value if interp else None, # diffusers API treats None as "weighted sum"
|
||||
force=force,
|
||||
**kwargs,
|
||||
)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return merged_pipe
|
||||
|
||||
def merge_diffusion_models_and_save(
|
||||
self,
|
||||
model_names: List[str],
|
||||
base_model: Union[BaseModelType, str],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
**kwargs,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||
:param base_model: base model (must be the same for all merged models!)
|
||||
:param merged_model_name: name for new model
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
model_paths = []
|
||||
config = self.manager.app_config
|
||||
base_model = BaseModelType(base_model)
|
||||
vae = None
|
||||
|
||||
for mod in model_names:
|
||||
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
|
||||
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
||||
assert (
|
||||
info["model_format"] == "diffusers"
|
||||
), f"{mod} is not a diffusers model. It must be optimized before merging"
|
||||
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
|
||||
assert (
|
||||
len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference
|
||||
), "When merging three models, only the 'add_difference' merge method is supported"
|
||||
# pick up the first model's vae
|
||||
if mod == model_names[0]:
|
||||
vae = info.get("vae")
|
||||
model_paths.extend([(config.root_path / info["path"]).as_posix()])
|
||||
|
||||
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
|
||||
logger.debug(f"interp = {interp}, merge_method={merge_method}")
|
||||
merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, **kwargs)
|
||||
dump_path = (
|
||||
Path(merge_dest_directory)
|
||||
if merge_dest_directory
|
||||
else config.models_path / base_model.value / ModelType.Main.value
|
||||
)
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = (dump_path / merged_model_name).as_posix()
|
||||
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
attributes = {
|
||||
"path": dump_path,
|
||||
"description": f"Merge of models {', '.join(model_names)}",
|
||||
"model_format": "diffusers",
|
||||
"variant": ModelVariantType.Normal.value,
|
||||
"vae": vae,
|
||||
}
|
||||
return self.manager.add_model(
|
||||
merged_model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Main,
|
||||
model_attributes=attributes,
|
||||
clobber=True,
|
||||
)
|
||||
@@ -1,664 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Literal, Optional, Union
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
||||
|
||||
from .models import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
)
|
||||
from .models.base import read_checkpoint_meta
|
||||
from .util import lora_token_vector_length
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelProbeInfo(object):
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
variant_type: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
|
||||
image_size: int
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ProbeBase(object):
|
||||
"""forward declaration"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelProbe(object):
|
||||
PROBES = {
|
||||
"diffusers": {},
|
||||
"checkpoint": {},
|
||||
"onnx": {},
|
||||
}
|
||||
|
||||
CLASS2TYPE = {
|
||||
"StableDiffusionPipeline": ModelType.Main,
|
||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_probe(
|
||||
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
|
||||
):
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(
|
||||
cls,
|
||||
model: Union[Dict, ModelMixin, Path],
|
||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||
) -> ModelProbeInfo:
|
||||
if isinstance(model, Path):
|
||||
return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
|
||||
elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
|
||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
||||
else:
|
||||
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
|
||||
|
||||
@classmethod
|
||||
def probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
model: Optional[Union[Dict, ModelMixin]] = None,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> ModelProbeInfo:
|
||||
"""
|
||||
Probe the model at model_path and return sufficient information about it
|
||||
to place it somewhere in the models directory hierarchy. If the model is
|
||||
already loaded into memory, you may provide it as model in order to avoid
|
||||
opening it a second time. The prediction_type_helper callable is a function that receives
|
||||
the path to the model and returns the SchedulerPredictionType.
|
||||
"""
|
||||
if model_path:
|
||||
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
|
||||
else:
|
||||
format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
|
||||
model_info = None
|
||||
try:
|
||||
model_type = (
|
||||
cls.get_model_type_from_folder(model_path, model)
|
||||
if format_type == "diffusers"
|
||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
||||
)
|
||||
format_type = "onnx" if model_type == ModelType.ONNX else format_type
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
if not probe_class:
|
||||
return None
|
||||
probe = probe_class(model_path, model, prediction_type_helper)
|
||||
base_type = probe.get_base_type()
|
||||
variant_type = probe.get_variant_type()
|
||||
prediction_type = probe.get_scheduler_prediction_type()
|
||||
name = cls.get_model_name(model_path)
|
||||
description = f"{base_type.value} {model_type.value} model {name}"
|
||||
format = probe.get_format()
|
||||
model_info = ModelProbeInfo(
|
||||
model_type=model_type,
|
||||
base_type=base_type,
|
||||
variant_type=variant_type,
|
||||
prediction_type=prediction_type,
|
||||
name=name,
|
||||
description=description,
|
||||
upcast_attention=(
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
),
|
||||
format=format,
|
||||
image_size=(
|
||||
1024
|
||||
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
|
||||
else (
|
||||
768
|
||||
if (
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
else 512
|
||||
)
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
return model_info
|
||||
|
||||
@classmethod
|
||||
def get_model_name(cls, model_path: Path) -> str:
|
||||
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
||||
return model_path.stem
|
||||
else:
|
||||
return model_path.name
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
||||
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
|
||||
return None
|
||||
|
||||
if model_path.name == "learned_embeds.bin":
|
||||
return ModelType.TextualInversion
|
||||
|
||||
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
for key in ckpt.keys():
|
||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||
return ModelType.Main
|
||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||
return ModelType.Vae
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
return ModelType.TextualInversion
|
||||
|
||||
else:
|
||||
# diffusers-ti
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
|
||||
"""
|
||||
Get the model type of a hugging-face style folder.
|
||||
"""
|
||||
class_name = None
|
||||
error_hint = None
|
||||
if model:
|
||||
class_name = model.__class__.__name__
|
||||
else:
|
||||
for suffix in ["bin", "safetensors"]:
|
||||
if (folder_path / f"learned_embeds.{suffix}").exists():
|
||||
return ModelType.TextualInversion
|
||||
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
||||
return ModelType.Lora
|
||||
if (folder_path / "unet/model.onnx").exists():
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
return ModelType.IPAdapter
|
||||
|
||||
i = folder_path / "model_index.json"
|
||||
c = folder_path / "config.json"
|
||||
config_path = i if i.exists() else c if c.exists() else None
|
||||
|
||||
if config_path:
|
||||
with open(config_path, "r") as file:
|
||||
conf = json.load(file)
|
||||
if "_class_name" in conf:
|
||||
class_name = conf["_class_name"]
|
||||
elif "architectures" in conf:
|
||||
class_name = conf["architectures"][0]
|
||||
else:
|
||||
class_name = None
|
||||
else:
|
||||
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
||||
|
||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||
return type
|
||||
else:
|
||||
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
|
||||
|
||||
# give up
|
||||
raise InvalidModelException(
|
||||
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
||||
with SilenceWarnings():
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||
cls._scan_model(model_path, model_path)
|
||||
return torch.load(model_path, map_location="cpu")
|
||||
else:
|
||||
return safetensors.torch.load_file(model_path)
|
||||
|
||||
@classmethod
|
||||
def _scan_model(cls, model_name, checkpoint):
|
||||
"""
|
||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
|
||||
|
||||
# ##################################################3
|
||||
# Checkpoint probing
|
||||
# ##################################################3
|
||||
class ProbeBase(object):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
pass
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
pass
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
pass
|
||||
|
||||
def get_format(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class CheckpointProbeBase(ProbeBase):
|
||||
def __init__(
|
||||
self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
|
||||
) -> BaseModelType:
|
||||
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.helper = helper
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
pass
|
||||
|
||||
def get_format(self) -> str:
|
||||
return "checkpoint"
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
|
||||
if model_type != ModelType.Main:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
if in_channels == 9:
|
||||
return ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
return ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
else:
|
||||
raise InvalidModelException(
|
||||
f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
|
||||
)
|
||||
|
||||
|
||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
else:
|
||||
raise InvalidModelException("Cannot determine base type")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
|
||||
"""Return model prediction type."""
|
||||
# if there is a .yaml associated with this checkpoint, then we do not need
|
||||
# to probe for the prediction type as it will be ignored.
|
||||
if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists():
|
||||
return None
|
||||
|
||||
type = self.get_base_type()
|
||||
if type == BaseModelType.StableDiffusion2:
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
if "global_step" in checkpoint:
|
||||
if checkpoint["global_step"] == 220000:
|
||||
return SchedulerPredictionType.Epsilon
|
||||
elif checkpoint["global_step"] == 110000:
|
||||
return SchedulerPredictionType.VPrediction
|
||||
if self.helper and self.checkpoint_path:
|
||||
if helper_guess := self.helper(self.checkpoint_path):
|
||||
return helper_guess
|
||||
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
|
||||
|
||||
elif type == BaseModelType.StableDiffusion1:
|
||||
if self.helper and self.checkpoint_path:
|
||||
if helper_guess := self.helper(self.checkpoint_path):
|
||||
return helper_guess
|
||||
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
# I can't find any standalone 2.X VAEs to test with!
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
|
||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return "lycoris"
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
token_vector_length = lora_token_vector_length(checkpoint)
|
||||
|
||||
if token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_vector_length == 1280:
|
||||
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
|
||||
|
||||
|
||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return None
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
if "string_to_token" in checkpoint:
|
||||
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
|
||||
elif "emb_params" in checkpoint:
|
||||
token_dim = checkpoint["emb_params"].shape[-1]
|
||||
elif "clip_g" in checkpoint:
|
||||
token_dim = checkpoint["clip_g"].shape[-1]
|
||||
else:
|
||||
token_dim = list(checkpoint.values())[0].shape[-1]
|
||||
if token_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_dim == 1280:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
for key_name in (
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
):
|
||||
if key_name not in checkpoint:
|
||||
continue
|
||||
if checkpoint[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif checkpoint[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif self.checkpoint_path and self.helper:
|
||||
return self.helper(self.checkpoint_path)
|
||||
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
||||
|
||||
|
||||
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
class FolderProbeBase(ProbeBase):
|
||||
def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used
|
||||
self.model = model
|
||||
self.folder_path = folder_path
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
return ModelVariantType.Normal
|
||||
|
||||
def get_format(self) -> str:
|
||||
return "diffusers"
|
||||
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self.model:
|
||||
unet_conf = self.model.unet.config
|
||||
else:
|
||||
with open(self.folder_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf["cross_attention_dim"] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif unet_conf["cross_attention_dim"] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
elif unet_conf["cross_attention_dim"] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"Unknown base model for {self.folder_path}")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
if self.model:
|
||||
scheduler_conf = self.model.scheduler.config
|
||||
else:
|
||||
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
scheduler_conf = json.load(file)
|
||||
if scheduler_conf["prediction_type"] == "v_prediction":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif scheduler_conf["prediction_type"] == "epsilon":
|
||||
return SchedulerPredictionType.Epsilon
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
# This only works for pipelines! Any kind of
|
||||
# exception results in our returning the
|
||||
# "normal" variant type
|
||||
try:
|
||||
if self.model:
|
||||
conf = self.model.unet.config
|
||||
else:
|
||||
config_file = self.folder_path / "unet" / "config.json"
|
||||
with open(config_file, "r") as file:
|
||||
conf = json.load(file)
|
||||
|
||||
in_channels = conf["in_channels"]
|
||||
if in_channels == 9:
|
||||
return ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
return ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
except Exception:
|
||||
pass
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self._config_looks_like_sdxl():
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif self._name_looks_like_sdxl():
|
||||
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def _config_looks_like_sdxl(self) -> bool:
|
||||
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
|
||||
def _name_looks_like_sdxl(self) -> bool:
|
||||
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
||||
|
||||
def _guess_name(self) -> str:
|
||||
name = self.folder_path.name
|
||||
if name == "vae":
|
||||
name = self.folder_path.parent.name
|
||||
return name
|
||||
|
||||
|
||||
class TextualInversionFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return None
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
path = self.folder_path / "learned_embeds.bin"
|
||||
if not path.exists():
|
||||
return None
|
||||
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
|
||||
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
||||
|
||||
|
||||
class ONNXFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return "onnx"
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class ControlNetFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
dimension = config["cross_attention_dim"]
|
||||
base_model = (
|
||||
BaseModelType.StableDiffusion1
|
||||
if dimension == 768
|
||||
else (
|
||||
BaseModelType.StableDiffusion2
|
||||
if dimension == 1024
|
||||
else BaseModelType.StableDiffusionXL
|
||||
if dimension == 2048
|
||||
else None
|
||||
)
|
||||
)
|
||||
if not base_model:
|
||||
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
|
||||
return base_model
|
||||
|
||||
|
||||
class LoRAFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
model_file = None
|
||||
for suffix in ["safetensors", "bin"]:
|
||||
base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
|
||||
if base_file.exists():
|
||||
model_file = base_file
|
||||
break
|
||||
if not model_file:
|
||||
raise InvalidModelException("Unknown LoRA format encountered")
|
||||
return LoRACheckpointProbe(model_file, None).get_base_type()
|
||||
|
||||
|
||||
class IPAdapterFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return IPAdapterModelFormat.InvokeAI.value
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
model_file = self.folder_path / "ip_adapter.bin"
|
||||
if not model_file.exists():
|
||||
raise InvalidModelException("Unknown IP-Adapter model format.")
|
||||
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||
if cross_attention_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif cross_attention_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif cross_attention_dim == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
|
||||
|
||||
|
||||
class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
|
||||
adapter_type = config.get("adapter_type", None)
|
||||
if adapter_type == "full_adapter_xl":
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif adapter_type == "full_adapter" or "light_adapter":
|
||||
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
|
||||
return BaseModelType.StableDiffusion1
|
||||
else:
|
||||
raise InvalidModelException(
|
||||
f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
|
||||
)
|
||||
|
||||
|
||||
############## register probe classes ######
|
||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||
@@ -1,112 +0,0 @@
|
||||
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||
"""
|
||||
Abstract base class for recursive directory search for models.
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Set, types
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
|
||||
class ModelSearch(ABC):
|
||||
def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
|
||||
"""
|
||||
Initialize a recursive model directory search.
|
||||
:param directories: List of directory Paths to recurse through
|
||||
:param logger: Logger to use
|
||||
"""
|
||||
self.directories = directories
|
||||
self.logger = logger
|
||||
self._items_scanned = 0
|
||||
self._models_found = 0
|
||||
self._scanned_dirs = set()
|
||||
self._scanned_paths = set()
|
||||
self._pruned_paths = set()
|
||||
|
||||
@abstractmethod
|
||||
def on_search_started(self):
|
||||
"""
|
||||
Called before the scan starts.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_model_found(self, model: Path):
|
||||
"""
|
||||
Process a found model. Raise an exception if something goes wrong.
|
||||
:param model: Model to process - could be a directory or checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_search_completed(self):
|
||||
"""
|
||||
Perform some activity when the scan is completed. May use instance
|
||||
variables, items_scanned and models_found
|
||||
"""
|
||||
pass
|
||||
|
||||
def search(self):
|
||||
self.on_search_started()
|
||||
for dir in self.directories:
|
||||
self.walk_directory(dir)
|
||||
self.on_search_completed()
|
||||
|
||||
def walk_directory(self, path: Path):
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
if str(Path(root).name).startswith("."):
|
||||
self._pruned_paths.add(root)
|
||||
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
|
||||
continue
|
||||
|
||||
self._items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
}
|
||||
):
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
self._scanned_dirs.add(path)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self._scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
||||
|
||||
|
||||
class FindModels(ModelSearch):
|
||||
def on_search_started(self):
|
||||
self.models_found: Set[Path] = set()
|
||||
|
||||
def on_model_found(self, model: Path):
|
||||
self.models_found.add(model)
|
||||
|
||||
def on_search_completed(self):
|
||||
pass
|
||||
|
||||
def list_models(self) -> List[Path]:
|
||||
self.search()
|
||||
return list(self.models_found)
|
||||
@@ -1,167 +0,0 @@
|
||||
import inspect
|
||||
from enum import Enum
|
||||
from typing import Literal, get_origin
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, create_model
|
||||
|
||||
from .base import ( # noqa: F401
|
||||
BaseModelType,
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelError,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
SubModelType,
|
||||
)
|
||||
from .clip_vision import CLIPVisionModel
|
||||
from .controlnet import ControlNetModel # TODO:
|
||||
from .ip_adapter import IPAdapterModel
|
||||
from .lora import LoRAModel
|
||||
from .sdxl import StableDiffusionXLModel
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
|
||||
from .t2i_adapter import T2IAdapterModel
|
||||
from .textual_inversion import TextualInversionModel
|
||||
from .vae import VaeModel
|
||||
|
||||
MODEL_CLASSES = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelType.ONNX: ONNXStableDiffusion1Model,
|
||||
ModelType.Main: StableDiffusion1Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.IPAdapter: IPAdapterModel,
|
||||
ModelType.CLIPVision: CLIPVisionModel,
|
||||
ModelType.T2IAdapter: T2IAdapterModel,
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
ModelType.Main: StableDiffusion2Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.IPAdapter: IPAdapterModel,
|
||||
ModelType.CLIPVision: CLIPVisionModel,
|
||||
ModelType.T2IAdapter: T2IAdapterModel,
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelType.Main: StableDiffusionXLModel,
|
||||
ModelType.Vae: VaeModel,
|
||||
# will not work until support written
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
ModelType.IPAdapter: IPAdapterModel,
|
||||
ModelType.CLIPVision: CLIPVisionModel,
|
||||
ModelType.T2IAdapter: T2IAdapterModel,
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelType.Main: StableDiffusionXLModel,
|
||||
ModelType.Vae: VaeModel,
|
||||
# will not work until support written
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
ModelType.IPAdapter: IPAdapterModel,
|
||||
ModelType.CLIPVision: CLIPVisionModel,
|
||||
ModelType.T2IAdapter: T2IAdapterModel,
|
||||
},
|
||||
BaseModelType.Any: {
|
||||
ModelType.CLIPVision: CLIPVisionModel,
|
||||
# The following model types are not expected to be used with BaseModelType.Any.
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
ModelType.Main: StableDiffusion2Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.IPAdapter: IPAdapterModel,
|
||||
ModelType.T2IAdapter: T2IAdapterModel,
|
||||
},
|
||||
# BaseModelType.Kandinsky2_1: {
|
||||
# ModelType.Main: Kandinsky2_1Model,
|
||||
# ModelType.MoVQ: MoVQModel,
|
||||
# ModelType.Lora: LoRAModel,
|
||||
# ModelType.ControlNet: ControlNetModel,
|
||||
# ModelType.TextualInversion: TextualInversionModel,
|
||||
# },
|
||||
}
|
||||
|
||||
MODEL_CONFIGS = []
|
||||
OPENAPI_MODEL_CONFIGS = []
|
||||
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
|
||||
model_name: str
|
||||
base_model: BaseModelType
|
||||
model_type: ModelType
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
for _base_model, models in MODEL_CLASSES.items():
|
||||
for model_type, model_class in models.items():
|
||||
model_configs = set(model_class._get_configs().values())
|
||||
model_configs.discard(None)
|
||||
MODEL_CONFIGS.extend(model_configs)
|
||||
|
||||
# LS: sort to get the checkpoint configs first, which makes
|
||||
# for a better template in the Swagger docs
|
||||
for cfg in sorted(model_configs, key=lambda x: str(x)):
|
||||
model_name, cfg_name = cfg.__qualname__.split(".")[-2:]
|
||||
openapi_cfg_name = model_name + cfg_name
|
||||
if openapi_cfg_name in vars():
|
||||
continue
|
||||
|
||||
api_wrapper = create_model(
|
||||
openapi_cfg_name,
|
||||
__base__=(cfg, OpenAPIModelInfoBase),
|
||||
model_type=(Literal[model_type], model_type), # type: ignore
|
||||
)
|
||||
vars()[openapi_cfg_name] = api_wrapper
|
||||
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
||||
|
||||
|
||||
def get_model_config_enums():
|
||||
enums = []
|
||||
|
||||
for model_config in MODEL_CONFIGS:
|
||||
if hasattr(inspect, "get_annotations"):
|
||||
fields = inspect.get_annotations(model_config)
|
||||
else:
|
||||
fields = model_config.__annotations__
|
||||
try:
|
||||
field = fields["model_format"]
|
||||
except Exception:
|
||||
raise Exception("format field not found")
|
||||
|
||||
# model_format: None
|
||||
# model_format: SomeModelFormat
|
||||
# model_format: Literal[SomeModelFormat.Diffusers]
|
||||
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
|
||||
|
||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||
enums.append(field)
|
||||
|
||||
elif get_origin(field) is Literal and all(
|
||||
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
|
||||
):
|
||||
enums.append(type(field.__args__[0]))
|
||||
|
||||
elif field is None:
|
||||
pass
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
|
||||
|
||||
return enums
|
||||
@@ -1,681 +0,0 @@
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from contextlib import suppress
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from diffusers import ConfigMixin, DiffusionPipeline
|
||||
from diffusers import logging as diffusers_logging
|
||||
from onnx import numpy_helper
|
||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||
from picklescan.scanner import scan_file_path
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidModelException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ModelNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
Any = "any" # For models that are not associated with any particular base model.
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
UNet = "unet"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
# MoVQ = "movq"
|
||||
|
||||
|
||||
class ModelVariantType(str, Enum):
|
||||
Normal = "normal"
|
||||
Inpaint = "inpaint"
|
||||
Depth = "depth"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
Epsilon = "epsilon"
|
||||
VPrediction = "v_prediction"
|
||||
Sample = "sample"
|
||||
|
||||
|
||||
class ModelError(str, Enum):
|
||||
NotFound = "not_found"
|
||||
|
||||
|
||||
def model_config_json_schema_extra(schema: dict[str, Any]) -> None:
|
||||
if "required" not in schema:
|
||||
schema["required"] = []
|
||||
schema["required"].append("model_type")
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
path: str # or Path
|
||||
description: Optional[str] = Field(None)
|
||||
model_format: Optional[str] = Field(None)
|
||||
error: Optional[ModelError] = Field(None)
|
||||
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=True, protected_namespaces=(), json_schema_extra=model_config_json_schema_extra
|
||||
)
|
||||
|
||||
|
||||
class EmptyConfigLoader(ConfigMixin):
|
||||
@classmethod
|
||||
def load_config(cls, *args, **kwargs):
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
return super().load_config(*args, **kwargs)
|
||||
|
||||
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
class classproperty(Generic[T_co]):
|
||||
def __init__(self, fget: Callable[[Any], T_co]) -> None:
|
||||
self.fget = fget
|
||||
|
||||
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
|
||||
return self.fget(owner)
|
||||
|
||||
def __set__(self, instance: Optional[Any], value: Any) -> None:
|
||||
raise AttributeError("cannot set attribute")
|
||||
|
||||
|
||||
class ModelBase(metaclass=ABCMeta):
|
||||
# model_path: str
|
||||
# base_model: BaseModelType
|
||||
# model_type: ModelType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
self.model_path = model_path
|
||||
self.base_model = base_model
|
||||
self.model_type = model_type
|
||||
|
||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
||||
if len(subtypes) < 2:
|
||||
raise Exception("Invalid subfolder definition!")
|
||||
if all(t is None for t in subtypes):
|
||||
return None
|
||||
elif any(t is None for t in subtypes):
|
||||
raise Exception(f"Unsupported definition: {subtypes}")
|
||||
|
||||
if subtypes[0] in ["diffusers", "transformers"]:
|
||||
res_type = sys.modules[subtypes[0]]
|
||||
subtypes = subtypes[1:]
|
||||
|
||||
else:
|
||||
res_type = sys.modules["diffusers"]
|
||||
res_type = res_type.pipelines
|
||||
|
||||
for subtype in subtypes:
|
||||
res_type = getattr(res_type, subtype)
|
||||
return res_type
|
||||
|
||||
@classmethod
|
||||
def _get_configs(cls):
|
||||
with suppress(Exception):
|
||||
return cls.__configs
|
||||
|
||||
configs = {}
|
||||
for name in dir(cls):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
|
||||
value = getattr(cls, name)
|
||||
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
||||
continue
|
||||
|
||||
if hasattr(inspect, "get_annotations"):
|
||||
fields = inspect.get_annotations(value)
|
||||
else:
|
||||
fields = value.__annotations__
|
||||
try:
|
||||
field = fields["model_format"]
|
||||
except Exception:
|
||||
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
|
||||
|
||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||
for model_format in field:
|
||||
configs[model_format.value] = value
|
||||
|
||||
elif typing.get_origin(field) is Literal and all(
|
||||
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
|
||||
):
|
||||
for model_format in field.__args__:
|
||||
configs[model_format.value] = value
|
||||
|
||||
elif field is None:
|
||||
configs[None] = value
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
|
||||
|
||||
cls.__configs = configs
|
||||
return cls.__configs
|
||||
|
||||
@classmethod
|
||||
def create_config(cls, **kwargs) -> ModelConfigBase:
|
||||
if "model_format" not in kwargs:
|
||||
raise Exception("Field 'model_format' not found in model config")
|
||||
|
||||
configs = cls._get_configs()
|
||||
return configs[kwargs["model_format"]](**kwargs)
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
model_format=cls.detect_format(path),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def detect_format(cls, path: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classproperty
|
||||
@abstractmethod
|
||||
def save_to_config(cls) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiffusersModel(ModelBase):
|
||||
# child_types: Dict[str, Type]
|
||||
# child_sizes: Dict[str, int]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.child_types: Dict[str, Type] = {}
|
||||
self.child_sizes: Dict[str, int] = {}
|
||||
|
||||
try:
|
||||
config_data = DiffusionPipeline.load_config(self.model_path)
|
||||
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
|
||||
except Exception:
|
||||
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
|
||||
|
||||
config_data.pop("_ignore_files", None)
|
||||
|
||||
# retrieve all folder_names that contain relevant files
|
||||
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
|
||||
|
||||
for child_name in child_components:
|
||||
child_type = self._hf_definition_to_type(config_data[child_name])
|
||||
self.child_types[child_name] = child_type
|
||||
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is None:
|
||||
return sum(self.child_sizes.values())
|
||||
else:
|
||||
return self.child_sizes[child_type]
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
):
|
||||
# return pipeline in different function to pass more arguments
|
||||
if child_type is None:
|
||||
raise Exception("Child model type can't be null on diffusers model")
|
||||
if child_type not in self.child_types:
|
||||
return None # TODO: or raise
|
||||
|
||||
if torch_dtype == torch.float16:
|
||||
variants = ["fp16", None]
|
||||
else:
|
||||
variants = [None, "fp16"]
|
||||
|
||||
# TODO: better error handling(differentiate not found from others)
|
||||
for variant in variants:
|
||||
try:
|
||||
# TODO: set cache_dir to /dev/null to be sure that cache not used?
|
||||
model = self.child_types[child_type].from_pretrained(
|
||||
self.model_path,
|
||||
subfolder=child_type.value,
|
||||
torch_dtype=torch_dtype,
|
||||
variant=variant,
|
||||
local_files_only=True,
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
if not str(e).startswith("Error no file"):
|
||||
print("====ERR LOAD====")
|
||||
print(f"{variant}: {e}")
|
||||
pass
|
||||
else:
|
||||
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
||||
|
||||
# calc more accurate size
|
||||
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
||||
return model
|
||||
|
||||
# def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
|
||||
|
||||
|
||||
def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None):
|
||||
if subfolder is not None:
|
||||
model_path = os.path.join(model_path, subfolder)
|
||||
|
||||
# this can happen when, for example, the safety checker
|
||||
# is not downloaded.
|
||||
if not os.path.exists(model_path):
|
||||
return 0
|
||||
|
||||
all_files = os.listdir(model_path)
|
||||
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
|
||||
|
||||
fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
|
||||
bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
|
||||
other_files = set(all_files) - fp16_files - bit8_files
|
||||
|
||||
if variant is None:
|
||||
files = other_files
|
||||
elif variant == "fp16":
|
||||
files = fp16_files
|
||||
elif variant == "8bit":
|
||||
files = bit8_files
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown variant: {variant}")
|
||||
|
||||
# try read from index if exists
|
||||
index_postfix = ".index.json"
|
||||
if variant is not None:
|
||||
index_postfix = f".index.{variant}.json"
|
||||
|
||||
for file in files:
|
||||
if not file.endswith(index_postfix):
|
||||
continue
|
||||
try:
|
||||
with open(os.path.join(model_path, file), "r") as f:
|
||||
index_data = json.loads(f.read())
|
||||
return int(index_data["metadata"]["total_size"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# calculate files size if there is no index file
|
||||
formats = [
|
||||
(".safetensors",), # safetensors
|
||||
(".bin",), # torch
|
||||
(".onnx", ".pb"), # onnx
|
||||
(".msgpack",), # flax
|
||||
(".ckpt",), # tf
|
||||
(".h5",), # tf2
|
||||
]
|
||||
|
||||
for file_format in formats:
|
||||
model_files = [f for f in files if f.endswith(file_format)]
|
||||
if len(model_files) == 0:
|
||||
continue
|
||||
|
||||
model_size = 0
|
||||
for model_file in model_files:
|
||||
file_stats = os.stat(os.path.join(model_path, model_file))
|
||||
model_size += file_stats.st_size
|
||||
return model_size
|
||||
|
||||
# raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
|
||||
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
||||
|
||||
|
||||
def calc_model_size_by_data(model) -> int:
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
return _calc_pipeline_by_data(model)
|
||||
elif isinstance(model, torch.nn.Module):
|
||||
return _calc_model_by_data(model)
|
||||
elif isinstance(model, IAIOnnxRuntimeModel):
|
||||
return _calc_onnx_model_by_data(model)
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def _calc_pipeline_by_data(pipeline) -> int:
|
||||
res = 0
|
||||
for submodel_key in pipeline.components.keys():
|
||||
submodel = getattr(pipeline, submodel_key)
|
||||
if submodel is not None and isinstance(submodel, torch.nn.Module):
|
||||
res += _calc_model_by_data(submodel)
|
||||
return res
|
||||
|
||||
|
||||
def _calc_model_by_data(model) -> int:
|
||||
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
|
||||
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
|
||||
mem = mem_params + mem_bufs # in bytes
|
||||
return mem
|
||||
|
||||
|
||||
def _calc_onnx_model_by_data(model) -> int:
|
||||
tensor_size = model.tensors.size() * 2 # The session doubles this
|
||||
mem = tensor_size # in bytes
|
||||
return mem
|
||||
|
||||
|
||||
def _fast_safetensors_reader(path: str):
|
||||
checkpoint = {}
|
||||
device = torch.device("meta")
|
||||
with open(path, "rb") as f:
|
||||
definition_len = int.from_bytes(f.read(8), "little")
|
||||
definition_json = f.read(definition_len)
|
||||
definition = json.loads(definition_json)
|
||||
|
||||
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
|
||||
"pt",
|
||||
"torch",
|
||||
"pytorch",
|
||||
}:
|
||||
raise Exception("Supported only pytorch safetensors files")
|
||||
definition.pop("__metadata__", None)
|
||||
|
||||
for key, info in definition.items():
|
||||
dtype = {
|
||||
"I8": torch.int8,
|
||||
"I16": torch.int16,
|
||||
"I32": torch.int32,
|
||||
"I64": torch.int64,
|
||||
"F16": torch.float16,
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
}[info["dtype"]]
|
||||
|
||||
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
||||
if str(path).endswith(".safetensors"):
|
||||
try:
|
||||
checkpoint = _fast_safetensors_reader(path)
|
||||
except Exception:
|
||||
# TODO: create issue for support "meta"?
|
||||
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
||||
else:
|
||||
if scan:
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
|
||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||
return checkpoint
|
||||
|
||||
|
||||
class SilenceWarnings(object):
|
||||
def __init__(self):
|
||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
||||
|
||||
def __enter__(self):
|
||||
transformers_logging.set_verbosity_error()
|
||||
diffusers_logging.set_verbosity_error()
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
||||
warnings.simplefilter("default")
|
||||
|
||||
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
|
||||
|
||||
class IAIOnnxRuntimeModel:
|
||||
class _tensor_access:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.indexes = {}
|
||||
for idx, obj in enumerate(self.model.proto.graph.initializer):
|
||||
self.indexes[obj.name] = idx
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
value = self.model.proto.graph.initializer[self.indexes[key]]
|
||||
return numpy_helper.to_array(value)
|
||||
|
||||
def __setitem__(self, key: str, value: np.ndarray):
|
||||
new_node = numpy_helper.from_array(value)
|
||||
# set_external_data(new_node, location="in-memory-location")
|
||||
new_node.name = key
|
||||
# new_node.ClearField("raw_data")
|
||||
del self.model.proto.graph.initializer[self.indexes[key]]
|
||||
self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
|
||||
# self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
|
||||
|
||||
# __delitem__
|
||||
|
||||
def __contains__(self, key: str):
|
||||
return self.indexes[key] in self.model.proto.graph.initializer
|
||||
|
||||
def items(self):
|
||||
raise NotImplementedError("tensor.items")
|
||||
# return [(obj.name, obj) for obj in self.raw_proto]
|
||||
|
||||
def keys(self):
|
||||
return self.indexes.keys()
|
||||
|
||||
def values(self):
|
||||
raise NotImplementedError("tensor.values")
|
||||
# return [obj for obj in self.raw_proto]
|
||||
|
||||
def size(self):
|
||||
bytesSum = 0
|
||||
for node in self.model.proto.graph.initializer:
|
||||
bytesSum += sys.getsizeof(node.raw_data)
|
||||
return bytesSum
|
||||
|
||||
class _access_helper:
|
||||
def __init__(self, raw_proto):
|
||||
self.indexes = {}
|
||||
self.raw_proto = raw_proto
|
||||
for idx, obj in enumerate(raw_proto):
|
||||
self.indexes[obj.name] = idx
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
return self.raw_proto[self.indexes[key]]
|
||||
|
||||
def __setitem__(self, key: str, value):
|
||||
index = self.indexes[key]
|
||||
del self.raw_proto[index]
|
||||
self.raw_proto.insert(index, value)
|
||||
|
||||
# __delitem__
|
||||
|
||||
def __contains__(self, key: str):
|
||||
return key in self.indexes
|
||||
|
||||
def items(self):
|
||||
return [(obj.name, obj) for obj in self.raw_proto]
|
||||
|
||||
def keys(self):
|
||||
return self.indexes.keys()
|
||||
|
||||
def values(self):
|
||||
return list(self.raw_proto)
|
||||
|
||||
def __init__(self, model_path: str, provider: Optional[str]):
|
||||
self.path = model_path
|
||||
self.session = None
|
||||
self.provider = provider
|
||||
"""
|
||||
self.data_path = self.path + "_data"
|
||||
if not os.path.exists(self.data_path):
|
||||
print(f"Moving model tensors to separate file: {self.data_path}")
|
||||
tmp_proto = onnx.load(model_path, load_external_data=True)
|
||||
onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
|
||||
del tmp_proto
|
||||
gc.collect()
|
||||
|
||||
self.proto = onnx.load(model_path, load_external_data=False)
|
||||
"""
|
||||
|
||||
self.proto = onnx.load(model_path, load_external_data=True)
|
||||
# self.data = dict()
|
||||
# for tensor in self.proto.graph.initializer:
|
||||
# name = tensor.name
|
||||
|
||||
# if tensor.HasField("raw_data"):
|
||||
# npt = numpy_helper.to_array(tensor)
|
||||
# orv = OrtValue.ortvalue_from_numpy(npt)
|
||||
# # self.data[name] = orv
|
||||
# # set_external_data(tensor, location="in-memory-location")
|
||||
# tensor.name = name
|
||||
# # tensor.ClearField("raw_data")
|
||||
|
||||
self.nodes = self._access_helper(self.proto.graph.node)
|
||||
# self.initializers = self._access_helper(self.proto.graph.initializer)
|
||||
# print(self.proto.graph.input)
|
||||
# print(self.proto.graph.initializer)
|
||||
|
||||
self.tensors = self._tensor_access(self)
|
||||
|
||||
# TODO: integrate with model manager/cache
|
||||
def create_session(self, height=None, width=None):
|
||||
if self.session is None or self.session_width != width or self.session_height != height:
|
||||
# onnx.save(self.proto, "tmp.onnx")
|
||||
# onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
|
||||
# TODO: something to be able to get weight when they already moved outside of model proto
|
||||
# (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
|
||||
sess = SessionOptions()
|
||||
# self._external_data.update(**external_data)
|
||||
# sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
|
||||
# sess.enable_profiling = True
|
||||
|
||||
# sess.intra_op_num_threads = 1
|
||||
# sess.inter_op_num_threads = 1
|
||||
# sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
|
||||
# sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
# sess.enable_cpu_mem_arena = True
|
||||
# sess.enable_mem_pattern = True
|
||||
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
|
||||
self.session_height = height
|
||||
self.session_width = width
|
||||
if height and width:
|
||||
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
|
||||
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
|
||||
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
|
||||
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
|
||||
providers = []
|
||||
if self.provider:
|
||||
providers.append(self.provider)
|
||||
else:
|
||||
providers = get_available_providers()
|
||||
if "TensorrtExecutionProvider" in providers:
|
||||
providers.remove("TensorrtExecutionProvider")
|
||||
try:
|
||||
self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
|
||||
except Exception as e:
|
||||
raise e
|
||||
# self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
|
||||
# self.io_binding = self.session.io_binding()
|
||||
|
||||
def release_session(self):
|
||||
self.session = None
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
return
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
if self.session is None:
|
||||
raise Exception("You should call create_session before running model")
|
||||
|
||||
inputs = {k: np.array(v) for k, v in kwargs.items()}
|
||||
# output_names = self.session.get_outputs()
|
||||
# for k in inputs:
|
||||
# self.io_binding.bind_cpu_input(k, inputs[k])
|
||||
# for name in output_names:
|
||||
# self.io_binding.bind_output(name.name)
|
||||
# self.session.run_with_iobinding(self.io_binding, None)
|
||||
# return self.io_binding.copy_outputs_to_cpu()
|
||||
return self.session.run(None, inputs)
|
||||
|
||||
# compatability with diffusers load code
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_id: Union[str, Path],
|
||||
subfolder: Union[str, Path] = None,
|
||||
file_name: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
sess_options: Optional["SessionOptions"] = None,
|
||||
**kwargs,
|
||||
):
|
||||
file_name = file_name or ONNX_WEIGHTS_NAME
|
||||
|
||||
if os.path.isdir(model_id):
|
||||
model_path = model_id
|
||||
if subfolder is not None:
|
||||
model_path = os.path.join(model_path, subfolder)
|
||||
model_path = os.path.join(model_path, file_name)
|
||||
|
||||
else:
|
||||
model_path = model_id
|
||||
|
||||
# load model from local directory
|
||||
if not os.path.isfile(model_path):
|
||||
raise Exception(f"Model not found: {model_path}")
|
||||
|
||||
# TODO: session options
|
||||
return cls(model_path, provider=provider)
|
||||
@@ -1,82 +0,0 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.model_management.models.base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
calc_model_size_by_data,
|
||||
calc_model_size_by_fs,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionModelFormat(str, Enum):
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class CLIPVisionModel(ModelBase):
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[CLIPVisionModelFormat.Diffusers]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.CLIPVision
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str) -> str:
|
||||
if not os.path.exists(path):
|
||||
raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.")
|
||||
|
||||
if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
|
||||
return CLIPVisionModelFormat.Diffusers
|
||||
|
||||
raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}")
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in a CLIP Vision model.")
|
||||
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
) -> CLIPVisionModelWithProjection:
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in a CLIP Vision model.")
|
||||
|
||||
model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)
|
||||
|
||||
# Calculate a more accurate model size.
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
format = cls.detect_format(model_path)
|
||||
if format == CLIPVisionModelFormat.Diffusers:
|
||||
return model_path
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: '{format}'.")
|
||||
@@ -1,163 +0,0 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
EmptyConfigLoader,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
calc_model_size_by_data,
|
||||
calc_model_size_by_fs,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class ControlNetModel(ModelBase):
|
||||
# model_class: Type
|
||||
# model_size: int
|
||||
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[ControlNetModelFormat.Diffusers]
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
model_format: Literal[ControlNetModelFormat.Checkpoint]
|
||||
config: str
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.ControlNet
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
try:
|
||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||
# config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||
except Exception:
|
||||
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
|
||||
|
||||
model_class_name = config.get("_class_name", None)
|
||||
if model_class_name not in {"ControlNetModel"}:
|
||||
raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")
|
||||
|
||||
try:
|
||||
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
|
||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||
except Exception:
|
||||
raise Exception("Invalid ControlNet model!")
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in controlnet model")
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
):
|
||||
if child_type is not None:
|
||||
raise Exception("There are no child models in controlnet model")
|
||||
|
||||
model = None
|
||||
for variant in ["fp16", None]:
|
||||
try:
|
||||
model = self.model_class.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
variant=variant,
|
||||
)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
if not model:
|
||||
raise ModelNotFoundException()
|
||||
|
||||
# calc more accurate size
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
return model
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
if os.path.exists(os.path.join(path, "config.json")):
|
||||
return ControlNetModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
|
||||
return ControlNetModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
|
||||
return _convert_controlnet_ckpt_and_cache(
|
||||
model_path=model_path,
|
||||
model_config=config.config,
|
||||
output_path=output_path,
|
||||
base_model=base_model,
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
def _convert_controlnet_ckpt_and_cache(
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_config: str,
|
||||
) -> str:
|
||||
"""
|
||||
Convert the controlnet from checkpoint format to diffusers format,
|
||||
cache it to disk, and return Path to converted
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
print(f"DEBUG: controlnet config = {model_config}")
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
weights = app_config.root_path / model_path
|
||||
output_path = Path(output_path)
|
||||
|
||||
logger.info(f"Converting {weights} to diffusers format")
|
||||
# return cached version if it exists
|
||||
if output_path.exists():
|
||||
return output_path
|
||||
|
||||
# to avoid circular import errors
|
||||
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
|
||||
convert_controlnet_to_diffusers(
|
||||
weights,
|
||||
output_path,
|
||||
original_config_file=app_config.root_path / model_config,
|
||||
image_size=512,
|
||||
scan_needed=True,
|
||||
from_safetensors=weights.suffix == ".safetensors",
|
||||
)
|
||||
return output_path
|
||||
@@ -1,98 +0,0 @@
|
||||
import os
|
||||
import typing
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter
|
||||
from invokeai.backend.model_management.models.base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
calc_model_size_by_fs,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
|
||||
class IPAdapterModelFormat(str, Enum):
|
||||
# The custom IP-Adapter model format defined by InvokeAI.
|
||||
InvokeAI = "invokeai"
|
||||
|
||||
|
||||
class IPAdapterModel(ModelBase):
|
||||
class InvokeAIConfig(ModelConfigBase):
|
||||
model_format: Literal[IPAdapterModelFormat.InvokeAI]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.IPAdapter
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str) -> str:
|
||||
if not os.path.exists(path):
|
||||
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
|
||||
|
||||
if os.path.isdir(path):
|
||||
model_file = os.path.join(path, "ip_adapter.bin")
|
||||
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
|
||||
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
|
||||
return IPAdapterModelFormat.InvokeAI
|
||||
|
||||
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
||||
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: torch.dtype,
|
||||
child_type: Optional[SubModelType] = None,
|
||||
) -> typing.Union[IPAdapter, IPAdapterPlus]:
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
||||
|
||||
model = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"),
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
|
||||
self.model_size = model.calc_size()
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
format = cls.detect_format(model_path)
|
||||
if format == IPAdapterModelFormat.InvokeAI:
|
||||
return model_path
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: '{format}'.")
|
||||
|
||||
|
||||
def get_ip_adapter_image_encoder_model_id(model_path: str):
|
||||
"""Read the ID of the image encoder associated with the IP-Adapter at `model_path`."""
|
||||
image_encoder_config_file = os.path.join(model_path, "image_encoder.txt")
|
||||
|
||||
with open(image_encoder_config_file, "r") as f:
|
||||
image_encoder_model = f.readline().strip()
|
||||
|
||||
return image_encoder_model
|
||||
@@ -1,148 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
DiffusersModel,
|
||||
InvalidModelException,
|
||||
ModelConfigBase,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
classproperty,
|
||||
read_checkpoint_meta,
|
||||
)
|
||||
|
||||
|
||||
class StableDiffusionXLModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class StableDiffusionXLModel(DiffusersModel):
|
||||
# TODO: check that configs overwriten properly
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: str
|
||||
variant: ModelVariantType
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
|
||||
assert model_type == ModelType.Main
|
||||
super().__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusionXL,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
model_format = cls.detect_format(path)
|
||||
ckpt_config_path = kwargs.get("config", None)
|
||||
if model_format == StableDiffusionXLModelFormat.Checkpoint:
|
||||
if ckpt_config_path:
|
||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||
|
||||
else:
|
||||
checkpoint = read_checkpoint_meta(path)
|
||||
checkpoint = checkpoint.get("state_dict", checkpoint)
|
||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
|
||||
elif model_format == StableDiffusionXLModelFormat.Diffusers:
|
||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||
if os.path.exists(unet_config_path):
|
||||
with open(unet_config_path, "r") as f:
|
||||
unet_config = json.loads(f.read())
|
||||
in_channels = unet_config["in_channels"]
|
||||
|
||||
else:
|
||||
raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model")
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
|
||||
|
||||
if in_channels == 9:
|
||||
variant = ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
variant = ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
variant = ModelVariantType.Normal
|
||||
else:
|
||||
raise Exception("Unkown stable diffusion 2.* model format")
|
||||
|
||||
if ckpt_config_path is None:
|
||||
# avoid circular import
|
||||
from .stable_diffusion import _select_ckpt_config
|
||||
|
||||
ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
model_format=model_format,
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if os.path.isdir(model_path):
|
||||
return StableDiffusionXLModelFormat.Diffusers
|
||||
else:
|
||||
return StableDiffusionXLModelFormat.Checkpoint
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
# The convert script adapted from the diffusers package uses
|
||||
# strings for the base model type. To avoid making too many
|
||||
# source code changes, we simply translate here
|
||||
if Path(output_path).exists():
|
||||
return output_path
|
||||
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
|
||||
|
||||
# Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed,
|
||||
# then we bake it into the converted model unless there is already
|
||||
# a nonstandard VAE installed.
|
||||
kwargs = {}
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix"
|
||||
if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)):
|
||||
InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.")
|
||||
kwargs["vae_path"] = vae_path
|
||||
|
||||
return _convert_ckpt_and_cache(
|
||||
version=base_model,
|
||||
model_config=config,
|
||||
output_path=output_path,
|
||||
use_safetensors=True,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
@@ -1,337 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
||||
from omegaconf import OmegaConf
|
||||
from pydantic import Field
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
DiffusersModel,
|
||||
InvalidModelException,
|
||||
ModelConfigBase,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SilenceWarnings,
|
||||
classproperty,
|
||||
read_checkpoint_meta,
|
||||
)
|
||||
from .sdxl import StableDiffusionXLModel
|
||||
|
||||
|
||||
class StableDiffusion1ModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class StableDiffusion1Model(DiffusersModel):
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: str
|
||||
variant: ModelVariantType
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1
|
||||
assert model_type == ModelType.Main
|
||||
super().__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
model_format = cls.detect_format(path)
|
||||
ckpt_config_path = kwargs.get("config", None)
|
||||
if model_format == StableDiffusion1ModelFormat.Checkpoint:
|
||||
if ckpt_config_path:
|
||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||
|
||||
else:
|
||||
checkpoint = read_checkpoint_meta(path)
|
||||
checkpoint = checkpoint.get("state_dict", checkpoint)
|
||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
|
||||
elif model_format == StableDiffusion1ModelFormat.Diffusers:
|
||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||
if os.path.exists(unet_config_path):
|
||||
with open(unet_config_path, "r") as f:
|
||||
unet_config = json.loads(f.read())
|
||||
in_channels = unet_config["in_channels"]
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
|
||||
|
||||
if in_channels == 9:
|
||||
variant = ModelVariantType.Inpaint
|
||||
elif in_channels == 4:
|
||||
variant = ModelVariantType.Normal
|
||||
else:
|
||||
raise Exception("Unkown stable diffusion 1.* model format")
|
||||
|
||||
if ckpt_config_path is None:
|
||||
ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant)
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
model_format=model_format,
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if not os.path.exists(model_path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(model_path):
|
||||
if os.path.exists(os.path.join(model_path, "model_index.json")):
|
||||
return StableDiffusion1ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return StableDiffusion1ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion1,
|
||||
model_config=config,
|
||||
load_safety_checker=False,
|
||||
output_path=output_path,
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
class StableDiffusion2ModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class StableDiffusion2Model(DiffusersModel):
|
||||
# TODO: check that configs overwriten properly
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: str
|
||||
variant: ModelVariantType
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion2
|
||||
assert model_type == ModelType.Main
|
||||
super().__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusion2,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
model_format = cls.detect_format(path)
|
||||
ckpt_config_path = kwargs.get("config", None)
|
||||
if model_format == StableDiffusion2ModelFormat.Checkpoint:
|
||||
if ckpt_config_path:
|
||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||
|
||||
else:
|
||||
checkpoint = read_checkpoint_meta(path)
|
||||
checkpoint = checkpoint.get("state_dict", checkpoint)
|
||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
|
||||
elif model_format == StableDiffusion2ModelFormat.Diffusers:
|
||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||
if os.path.exists(unet_config_path):
|
||||
with open(unet_config_path, "r") as f:
|
||||
unet_config = json.loads(f.read())
|
||||
in_channels = unet_config["in_channels"]
|
||||
|
||||
else:
|
||||
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
|
||||
|
||||
if in_channels == 9:
|
||||
variant = ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
variant = ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
variant = ModelVariantType.Normal
|
||||
else:
|
||||
raise Exception("Unkown stable diffusion 2.* model format")
|
||||
|
||||
if ckpt_config_path is None:
|
||||
ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant)
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
model_format=model_format,
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if not os.path.exists(model_path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(model_path):
|
||||
if os.path.exists(os.path.join(model_path, "model_index.json")):
|
||||
return StableDiffusion2ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return StableDiffusion2ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion2,
|
||||
model_config=config,
|
||||
output_path=output_path,
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
# TODO: rework
|
||||
# pass precision - currently defaulting to fp16
|
||||
def _convert_ckpt_and_cache(
|
||||
version: BaseModelType,
|
||||
model_config: Union[
|
||||
StableDiffusion1Model.CheckpointConfig,
|
||||
StableDiffusion2Model.CheckpointConfig,
|
||||
StableDiffusionXLModel.CheckpointConfig,
|
||||
],
|
||||
output_path: str,
|
||||
use_save_model: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Convert the checkpoint model indicated in mconfig into a
|
||||
diffusers, cache it to disk, and return Path to converted
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
|
||||
weights = app_config.models_path / model_config.path
|
||||
config_file = app_config.root_path / model_config.config
|
||||
output_path = Path(output_path)
|
||||
variant = model_config.variant
|
||||
pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline
|
||||
|
||||
# return cached version if it exists
|
||||
if output_path.exists():
|
||||
return output_path
|
||||
|
||||
# to avoid circular import errors
|
||||
from ...util.devices import choose_torch_device, torch_dtype
|
||||
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
|
||||
model_base_to_model_type = {
|
||||
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusionXL: "SDXL",
|
||||
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
|
||||
}
|
||||
logger.info(f"Converting {weights} to diffusers format")
|
||||
with SilenceWarnings():
|
||||
convert_ckpt_to_diffusers(
|
||||
weights,
|
||||
output_path,
|
||||
model_type=model_base_to_model_type[version],
|
||||
model_version=version,
|
||||
model_variant=model_config.variant,
|
||||
original_config_file=config_file,
|
||||
extract_ema=True,
|
||||
scan_needed=True,
|
||||
pipeline_class=pipeline_class,
|
||||
from_safetensors=weights.suffix == ".safetensors",
|
||||
precision=torch_dtype(choose_torch_device()),
|
||||
**kwargs,
|
||||
)
|
||||
return output_path
|
||||
|
||||
|
||||
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
||||
ckpt_configs = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: "v1-inference.yaml",
|
||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
|
||||
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
|
||||
ModelVariantType.Depth: "v2-midas-inference.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||
ModelVariantType.Inpaint: None,
|
||||
ModelVariantType.Depth: None,
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
ModelVariantType.Inpaint: None,
|
||||
ModelVariantType.Depth: None,
|
||||
},
|
||||
}
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
try:
|
||||
config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
|
||||
if config_path.is_relative_to(app_config.root_path):
|
||||
config_path = config_path.relative_to(app_config.root_path)
|
||||
return str(config_path)
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
@@ -1,150 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from diffusers import OnnxRuntimeModel
|
||||
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
DiffusersModel,
|
||||
IAIOnnxRuntimeModel,
|
||||
ModelConfigBase,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
|
||||
class StableDiffusionOnnxModelFormat(str, Enum):
|
||||
Olive = "olive"
|
||||
Onnx = "onnx"
|
||||
|
||||
|
||||
class ONNXStableDiffusion1Model(DiffusersModel):
|
||||
class Config(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
||||
variant: ModelVariantType
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1
|
||||
assert model_type == ModelType.ONNX
|
||||
super().__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.ONNX,
|
||||
)
|
||||
|
||||
for child_name, child_type in self.child_types.items():
|
||||
if child_type is OnnxRuntimeModel:
|
||||
self.child_types[child_name] = IAIOnnxRuntimeModel
|
||||
|
||||
# TODO: check that no optimum models provided
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
model_format = cls.detect_format(path)
|
||||
in_channels = 4 # TODO:
|
||||
|
||||
if in_channels == 9:
|
||||
variant = ModelVariantType.Inpaint
|
||||
elif in_channels == 4:
|
||||
variant = ModelVariantType.Normal
|
||||
else:
|
||||
raise Exception("Unkown stable diffusion 1.* model format")
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
model_format=model_format,
|
||||
variant=variant,
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
# TODO: Detect onnx vs olive
|
||||
return StableDiffusionOnnxModelFormat.Onnx
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
return model_path
|
||||
|
||||
|
||||
class ONNXStableDiffusion2Model(DiffusersModel):
|
||||
# TODO: check that configs overwriten properly
|
||||
class Config(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
||||
variant: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion2
|
||||
assert model_type == ModelType.ONNX
|
||||
super().__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusion2,
|
||||
model_type=ModelType.ONNX,
|
||||
)
|
||||
|
||||
for child_name, child_type in self.child_types.items():
|
||||
if child_type is OnnxRuntimeModel:
|
||||
self.child_types[child_name] = IAIOnnxRuntimeModel
|
||||
# TODO: check that no optimum models provided
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
model_format = cls.detect_format(path)
|
||||
in_channels = 4 # TODO:
|
||||
|
||||
if in_channels == 9:
|
||||
variant = ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
variant = ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
variant = ModelVariantType.Normal
|
||||
else:
|
||||
raise Exception("Unkown stable diffusion 2.* model format")
|
||||
|
||||
if variant == ModelVariantType.Normal:
|
||||
prediction_type = SchedulerPredictionType.VPrediction
|
||||
upcast_attention = True
|
||||
|
||||
else:
|
||||
prediction_type = SchedulerPredictionType.Epsilon
|
||||
upcast_attention = False
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
model_format=model_format,
|
||||
variant=variant,
|
||||
prediction_type=prediction_type,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
# TODO: Detect onnx vs olive
|
||||
return StableDiffusionOnnxModelFormat.Onnx
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
return model_path
|
||||
@@ -1,102 +0,0 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import T2IAdapter
|
||||
|
||||
from invokeai.backend.model_management.models.base import (
|
||||
BaseModelType,
|
||||
EmptyConfigLoader,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
calc_model_size_by_data,
|
||||
calc_model_size_by_fs,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
|
||||
class T2IAdapterModelFormat(str, Enum):
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class T2IAdapterModel(ModelBase):
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[T2IAdapterModelFormat.Diffusers]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.T2IAdapter
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||
|
||||
model_class_name = config.get("_class_name", None)
|
||||
if model_class_name not in {"T2IAdapter"}:
|
||||
raise InvalidModelException(f"Invalid T2I-Adapter model. Unknown _class_name: '{model_class_name}'.")
|
||||
|
||||
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
|
||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is not None:
|
||||
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
) -> T2IAdapter:
|
||||
if child_type is not None:
|
||||
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
|
||||
|
||||
model = None
|
||||
for variant in ["fp16", None]:
|
||||
try:
|
||||
model = self.model_class.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
variant=variant,
|
||||
)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
if not model:
|
||||
raise ModelNotFoundException()
|
||||
|
||||
# Calculate a more accurate size after loading the model into memory.
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
return model
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException(f"Model not found at '{path}'.")
|
||||
|
||||
if os.path.isdir(path):
|
||||
if os.path.exists(os.path.join(path, "config.json")):
|
||||
return T2IAdapterModelFormat.Diffusers
|
||||
|
||||
raise InvalidModelException(f"Unsupported T2I-Adapter format: '{path}'.")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
format = cls.detect_format(model_path)
|
||||
if format == T2IAdapterModelFormat.Diffusers:
|
||||
return model_path
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: '{format}'.")
|
||||
@@ -1,87 +0,0 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
|
||||
class TextualInversionModel(ModelBase):
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: None
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.TextualInversion
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.model_size = os.path.getsize(self.model_path)
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in textual inversion")
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in textual inversion")
|
||||
|
||||
checkpoint_path = self.model_path
|
||||
if os.path.isdir(checkpoint_path):
|
||||
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
|
||||
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
model = TextualInversionModelRaw.from_checkpoint(
|
||||
file_path=checkpoint_path,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
|
||||
self.model_size = model.embedding.nelement() * model.embedding.element_size()
|
||||
return model
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
if os.path.exists(os.path.join(path, "learned_embeds.bin")):
|
||||
return None # diffusers-ti
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
|
||||
return None
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
return model_path
|
||||
@@ -1,179 +0,0 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
EmptyConfigLoader,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
calc_model_size_by_data,
|
||||
calc_model_size_by_fs,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
|
||||
class VaeModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class VaeModel(ModelBase):
|
||||
# vae_class: Type
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: VaeModelFormat
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Vae
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
try:
|
||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||
# config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||
except Exception:
|
||||
raise Exception("Invalid vae model! (config.json not found or invalid)")
|
||||
|
||||
try:
|
||||
vae_class_name = config.get("_class_name", "AutoencoderKL")
|
||||
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
|
||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||
except Exception:
|
||||
raise Exception("Invalid vae model! (Unkown vae type)")
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in vae model")
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in vae model")
|
||||
|
||||
model = self.vae_class.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
# calc more accurate size
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
return model
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException(f"Does not exist as local file: {path}")
|
||||
|
||||
if os.path.isdir(path):
|
||||
if os.path.exists(os.path.join(path, "config.json")):
|
||||
return VaeModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return VaeModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase, # empty config or config of parent model
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
|
||||
return _convert_vae_ckpt_and_cache(
|
||||
weights_path=model_path,
|
||||
output_path=output_path,
|
||||
base_model=base_model,
|
||||
model_config=config,
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
# TODO: rework
|
||||
def _convert_vae_ckpt_and_cache(
|
||||
weights_path: str,
|
||||
output_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_config: ModelConfigBase,
|
||||
) -> str:
|
||||
"""
|
||||
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
|
||||
object, cache it to disk, and return Path to converted
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
weights_path = app_config.root_dir / weights_path
|
||||
output_path = Path(output_path)
|
||||
|
||||
"""
|
||||
this size used only in when tiling enabled to separate input in tiles
|
||||
sizes in configs from stable diffusion githubs(1 and 2) set to 256
|
||||
on huggingface it:
|
||||
1.5 - 512
|
||||
1.5-inpainting - 256
|
||||
2-inpainting - 512
|
||||
2-depth - 256
|
||||
2-base - 512
|
||||
2 - 768
|
||||
2.1-base - 768
|
||||
2.1 - 768
|
||||
"""
|
||||
image_size = 512
|
||||
|
||||
# return cached version if it exists
|
||||
if output_path.exists():
|
||||
return output_path
|
||||
|
||||
if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
from .stable_diffusion import _select_ckpt_config
|
||||
|
||||
# all sd models use same vae settings
|
||||
config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)
|
||||
else:
|
||||
raise Exception(f"Vae conversion not supported for model type: {base_model}")
|
||||
|
||||
# this avoids circular import error
|
||||
from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
|
||||
if weights_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(weights_path, map_location="cpu")
|
||||
|
||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
config = OmegaConf.load(app_config.root_path / config_file)
|
||||
|
||||
vae_model = convert_ldm_vae_to_diffusers(
|
||||
checkpoint=checkpoint,
|
||||
vae_config=config,
|
||||
image_size=image_size,
|
||||
)
|
||||
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||
return output_path
|
||||
@@ -1,102 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
|
||||
|
||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
"""
|
||||
Patch for Conv2d._conv_forward that supports asymmetric padding
|
||||
"""
|
||||
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
|
||||
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
|
||||
return nn.functional.conv2d(
|
||||
working,
|
||||
weight,
|
||||
bias,
|
||||
self.stride,
|
||||
nn.modules.utils._pair(0),
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
|
||||
try:
|
||||
to_restore = []
|
||||
|
||||
for m_name, m in model.named_modules():
|
||||
if isinstance(model, UNet2DConditionModel):
|
||||
if ".attentions." in m_name:
|
||||
continue
|
||||
|
||||
if ".resnets." in m_name:
|
||||
if ".conv2" in m_name:
|
||||
continue
|
||||
if ".conv_shortcut" in m_name:
|
||||
continue
|
||||
|
||||
"""
|
||||
if isinstance(model, UNet2DConditionModel):
|
||||
if False and ".upsamplers." in m_name:
|
||||
continue
|
||||
|
||||
if False and ".downsamplers." in m_name:
|
||||
continue
|
||||
|
||||
if True and ".resnets." in m_name:
|
||||
if True and ".conv1" in m_name:
|
||||
if False and "down_blocks" in m_name:
|
||||
continue
|
||||
if False and "mid_block" in m_name:
|
||||
continue
|
||||
if False and "up_blocks" in m_name:
|
||||
continue
|
||||
|
||||
if True and ".conv2" in m_name:
|
||||
continue
|
||||
|
||||
if True and ".conv_shortcut" in m_name:
|
||||
continue
|
||||
|
||||
if True and ".attentions." in m_name:
|
||||
continue
|
||||
|
||||
if False and m_name in ["conv_in", "conv_out"]:
|
||||
continue
|
||||
"""
|
||||
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m.asymmetric_padding_mode = {}
|
||||
m.asymmetric_padding = {}
|
||||
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
|
||||
m.asymmetric_padding["x"] = (
|
||||
m._reversed_padding_repeated_twice[0],
|
||||
m._reversed_padding_repeated_twice[1],
|
||||
0,
|
||||
0,
|
||||
)
|
||||
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
|
||||
m.asymmetric_padding["y"] = (
|
||||
0,
|
||||
0,
|
||||
m._reversed_padding_repeated_twice[2],
|
||||
m._reversed_padding_repeated_twice[3],
|
||||
)
|
||||
|
||||
to_restore.append((m, m._conv_forward))
|
||||
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
for module, orig_conv_forward in to_restore:
|
||||
module._conv_forward = orig_conv_forward
|
||||
if hasattr(module, "asymmetric_padding_mode"):
|
||||
del module.asymmetric_padding_mode
|
||||
if hasattr(module, "asymmetric_padding"):
|
||||
del module.asymmetric_padding
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Re-export frequently-used symbols from the Model Manager backend."""
|
||||
|
||||
from .config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
@@ -12,14 +12,17 @@ from .config import (
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from .load import LoadedModel
|
||||
from .probe import ModelProbe
|
||||
from .search import ModelSearch
|
||||
|
||||
__all__ = [
|
||||
"AnyModel",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelRepoVariant",
|
||||
"InvalidModelConfigException",
|
||||
"LoadedModel",
|
||||
"ModelConfigFactory",
|
||||
"ModelFormat",
|
||||
"ModelProbe",
|
||||
@@ -29,3 +32,42 @@ __all__ = [
|
||||
"SchedulerPredictionType",
|
||||
"SubModelType",
|
||||
]
|
||||
|
||||
########## to help populate the openapi_schema with format enums for each config ###########
|
||||
# This code is no longer necessary?
|
||||
# leave it here just in case
|
||||
#
|
||||
# import inspect
|
||||
# from enum import Enum
|
||||
# from typing import Any, Iterable, Dict, get_args, Set
|
||||
# def _expand(something: Any) -> Iterable[type]:
|
||||
# if isinstance(something, type):
|
||||
# yield something
|
||||
# else:
|
||||
# for x in get_args(something):
|
||||
# for y in _expand(x):
|
||||
# yield y
|
||||
|
||||
# def _find_format(cls: type) -> Iterable[Enum]:
|
||||
# if hasattr(inspect, "get_annotations"):
|
||||
# fields = inspect.get_annotations(cls)
|
||||
# else:
|
||||
# fields = cls.__annotations__
|
||||
# if "format" in fields:
|
||||
# for x in get_args(fields["format"]):
|
||||
# yield x
|
||||
# for parent_class in cls.__bases__:
|
||||
# for x in _find_format(parent_class):
|
||||
# yield x
|
||||
# return None
|
||||
|
||||
# def get_model_config_formats() -> Dict[str, Set[Enum]]:
|
||||
# result: Dict[str, Set[Enum]] = {}
|
||||
# for model_config in _expand(AnyModelConfig):
|
||||
# for field in _find_format(model_config):
|
||||
# if field is None:
|
||||
# continue
|
||||
# if not result.get(model_config.__qualname__):
|
||||
# result[model_config.__qualname__] = set()
|
||||
# result[model_config.__qualname__].add(field)
|
||||
# return result
|
||||
|
||||
@@ -19,12 +19,21 @@ Typical usage:
|
||||
Validation errors will raise an InvalidModelConfigException error.
|
||||
|
||||
"""
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
@@ -102,7 +111,7 @@ class SchedulerPredictionType(str, Enum):
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
DEFAULT = "default" # model files without "fp16" or other qualifier
|
||||
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
@@ -113,11 +122,11 @@ class ModelRepoVariant(str, Enum):
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base class for model configuration information."""
|
||||
|
||||
path: str
|
||||
name: str
|
||||
base: BaseModelType
|
||||
type: ModelType
|
||||
format: ModelFormat
|
||||
path: str = Field(description="filesystem path to the model file or directory")
|
||||
name: str = Field(description="model name")
|
||||
base: BaseModelType = Field(description="base model")
|
||||
type: ModelType = Field(description="type of the model")
|
||||
format: ModelFormat = Field(description="model format")
|
||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||
original_hash: Optional[str] = Field(
|
||||
description="original fasthash of model contents", default=None
|
||||
@@ -125,8 +134,9 @@ class ModelConfigBase(BaseModel):
|
||||
current_hash: Optional[str] = Field(
|
||||
description="current fasthash of model contents", default=None
|
||||
) # if model is converted or otherwise modified, this will hold updated hash
|
||||
description: Optional[str] = Field(default=None)
|
||||
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
|
||||
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
||||
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
|
||||
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
|
||||
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=False,
|
||||
@@ -150,6 +160,7 @@ class _DiffusersConfig(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||
|
||||
|
||||
class LoRAConfig(ModelConfigBase):
|
||||
@@ -199,6 +210,8 @@ class _MainConfig(ModelConfigBase):
|
||||
|
||||
vae: Optional[str] = Field(default=None)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
ztsnr_training: bool = False
|
||||
|
||||
|
||||
@@ -212,8 +225,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
class ONNXSD1Config(_MainConfig):
|
||||
@@ -237,10 +248,21 @@ class ONNXSD2Config(_MainConfig):
|
||||
upcast_attention: bool = True
|
||||
|
||||
|
||||
class ONNXSDXLConfig(_MainConfig):
|
||||
"""Model config for ONNX format models based on sdxl."""
|
||||
|
||||
type: Literal[ModelType.ONNX] = ModelType.ONNX
|
||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||
# No yaml config file for ONNX, so these are part of config
|
||||
base: Literal[BaseModelType.StableDiffusionXL] = BaseModelType.StableDiffusionXL
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
|
||||
|
||||
|
||||
class IPAdapterConfig(ModelConfigBase):
|
||||
"""Model config for IP Adaptor format models."""
|
||||
|
||||
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
||||
image_encoder_model_id: str
|
||||
format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
|
||||
@@ -258,7 +280,7 @@ class T2IConfig(ModelConfigBase):
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
|
||||
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
|
||||
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config, ONNXSDXLConfig], Field(discriminator="base")]
|
||||
_ControlNetConfig = Annotated[
|
||||
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
||||
Field(discriminator="format"),
|
||||
@@ -271,6 +293,7 @@ AnyModelConfig = Union[
|
||||
_ONNXConfig,
|
||||
_VaeConfig,
|
||||
_ControlNetConfig,
|
||||
# ModelConfigBase,
|
||||
LoRAConfig,
|
||||
TextualInversionConfig,
|
||||
IPAdapterConfig,
|
||||
@@ -280,6 +303,7 @@ AnyModelConfig = Union[
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
|
||||
# IMPLEMENTATION NOTE:
|
||||
# The preferred alternative to the above is a discriminated Union as shown
|
||||
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
|
||||
@@ -308,9 +332,10 @@ class ModelConfigFactory(object):
|
||||
@classmethod
|
||||
def make_config(
|
||||
cls,
|
||||
model_data: Union[dict, AnyModelConfig],
|
||||
model_data: Union[Dict[str, Any], AnyModelConfig],
|
||||
key: Optional[str] = None,
|
||||
dest_class: Optional[Type] = None,
|
||||
dest_class: Optional[Type[ModelConfigBase]] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Return the appropriate config object from raw dict values.
|
||||
@@ -321,12 +346,17 @@ class ModelConfigFactory(object):
|
||||
:param dest_class: The config class to be returned. If not provided, will
|
||||
be selected automatically.
|
||||
"""
|
||||
model: Optional[ModelConfigBase] = None
|
||||
if isinstance(model_data, ModelConfigBase):
|
||||
model = model_data
|
||||
elif dest_class:
|
||||
model = dest_class.validate_python(model_data)
|
||||
model = dest_class.model_validate(model_data)
|
||||
else:
|
||||
model = AnyModelConfigValidator.validate_python(model_data)
|
||||
# mypy doesn't typecheck TypeAdapters well?
|
||||
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
|
||||
assert model is not None
|
||||
if key:
|
||||
model.key = key
|
||||
return model
|
||||
if timestamp:
|
||||
model.last_modified = timestamp
|
||||
return model # type: ignore
|
||||
|
||||
@@ -57,10 +57,9 @@ from transformers import (
|
||||
)
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .models import BaseModelType, ModelVariantType
|
||||
|
||||
try:
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
@@ -1643,6 +1642,8 @@ def download_controlnet_from_original_ckpt(
|
||||
cross_attention_dim: Optional[bool] = None,
|
||||
scan_needed: bool = False,
|
||||
) -> DiffusionPipeline:
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if from_safetensors:
|
||||
from safetensors import safe_open
|
||||
|
||||
@@ -1718,6 +1719,7 @@ def convert_ckpt_to_diffusers(
|
||||
"""
|
||||
pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
|
||||
|
||||
# TO DO: save correct repo variant
|
||||
pipe.save_pretrained(
|
||||
dump_path,
|
||||
safe_serialization=use_safetensors,
|
||||
@@ -1736,4 +1738,5 @@ def convert_controlnet_to_diffusers(
|
||||
"""
|
||||
pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs)
|
||||
|
||||
# TO DO: save correct repo variant
|
||||
pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
27
invokeai/backend/model_manager/load/__init__.py
Normal file
27
invokeai/backend/model_manager/load/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Init file for the model loader.
|
||||
"""
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||
from .load_base import LoadedModel, ModelLoaderBase
|
||||
from .load_default import ModelLoader
|
||||
from .model_cache.model_cache_default import ModelCache
|
||||
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
|
||||
# This registers the subclasses that implement loaders of specific model types
|
||||
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
|
||||
for module in loaders:
|
||||
import_module(f"{__package__}.model_loaders.{module}")
|
||||
|
||||
__all__ = [
|
||||
"LoadedModel",
|
||||
"ModelCache",
|
||||
"ModelConvertCache",
|
||||
"ModelLoaderBase",
|
||||
"ModelLoader",
|
||||
"ModelLoaderRegistryBase",
|
||||
"ModelLoaderRegistry",
|
||||
]
|
||||
@@ -0,0 +1,4 @@
|
||||
from .convert_cache_base import ModelConvertCacheBase
|
||||
from .convert_cache_default import ModelConvertCache
|
||||
|
||||
__all__ = ["ModelConvertCacheBase", "ModelConvertCache"]
|
||||
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
Disk-based converted model cache.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ModelConvertCacheBase(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_size(self) -> float:
|
||||
"""Return the maximum size of this cache directory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def make_room(self, size: float) -> None:
|
||||
"""
|
||||
Make sufficient room in the cache directory for a model of max_size.
|
||||
|
||||
:param size: Size required (GB)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_path(self, key: str) -> Path:
|
||||
"""Return the path for a model with the indicated key."""
|
||||
pass
|
||||
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Placeholder for convert cache implementation.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.backend.util import GIG, directory_size
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .convert_cache_base import ModelConvertCacheBase
|
||||
|
||||
|
||||
class ModelConvertCache(ModelConvertCacheBase):
|
||||
def __init__(self, cache_path: Path, max_size: float = 10.0):
|
||||
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
|
||||
if not cache_path.exists():
|
||||
cache_path.mkdir(parents=True)
|
||||
self._cache_path = cache_path
|
||||
self._max_size = max_size
|
||||
|
||||
@property
|
||||
def max_size(self) -> float:
|
||||
"""Return the maximum size of this cache directory (GB)."""
|
||||
return self._max_size
|
||||
|
||||
def cache_path(self, key: str) -> Path:
|
||||
"""Return the path for a model with the indicated key."""
|
||||
return self._cache_path / key
|
||||
|
||||
def make_room(self, size: float) -> None:
|
||||
"""
|
||||
Make sufficient room in the cache directory for a model of max_size.
|
||||
|
||||
:param size: Size required (GB)
|
||||
"""
|
||||
size_needed = directory_size(self._cache_path) + size
|
||||
max_size = int(self.max_size) * GIG
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
if size_needed <= max_size:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming."
|
||||
)
|
||||
|
||||
# For this to work, we make the assumption that the directory contains
|
||||
# a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level.
|
||||
# This should be true for any diffusers model.
|
||||
def by_atime(path: Path) -> float:
|
||||
for config in ["model_index.json", "unet/config.json", "config.json"]:
|
||||
sentinel = path / config
|
||||
if sentinel.exists():
|
||||
return sentinel.stat().st_atime
|
||||
|
||||
# no sentinel file found! - pick the most recent file in the directory
|
||||
try:
|
||||
atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True)
|
||||
return atimes[0]
|
||||
except IndexError:
|
||||
return 0.0
|
||||
|
||||
# sort by last access time - least accessed files will be at the end
|
||||
lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)
|
||||
logger.debug(f"cached models in descending atime order: {lru_models}")
|
||||
while size_needed > max_size and len(lru_models) > 0:
|
||||
next_victim = lru_models.pop()
|
||||
victim_size = directory_size(next_victim)
|
||||
logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
|
||||
shutil.rmtree(next_victim)
|
||||
size_needed -= victim_size
|
||||
77
invokeai/backend/model_manager/load/load_base.py
Normal file
77
invokeai/backend/model_manager/load/load_base.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Base class for model loading in InvokeAI.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel:
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
config: AnyModelConfig
|
||||
_locker: ModelLockerBase
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
"""Context entry."""
|
||||
self._locker.lock()
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Context exit."""
|
||||
self._locker.unlock()
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without locking it."""
|
||||
return self._locker.model
|
||||
|
||||
|
||||
class ModelLoaderBase(ABC):
|
||||
"""Abstract base class for loading models into RAM/VRAM."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Return a model given its confguration.
|
||||
|
||||
Given a model identified in the model configuration backend,
|
||||
return a ModelInfo object that can be used to retrieve the model.
|
||||
|
||||
:param model_config: Model configuration, as returned by ModelConfigRecordStore
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_size_fs(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> int:
|
||||
"""Return size in bytes of the model, calculated before loading."""
|
||||
pass
|
||||
136
invokeai/backend/model_manager/load/load_default.py
Normal file
136
invokeai/backend/model_manager/load/load_default.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Default implementation of model loading in InvokeAI."""
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
|
||||
# TO DO: The loader is not thread safe!
|
||||
class ModelLoader(ModelLoaderBase):
|
||||
"""Default implementation of ModelLoaderBase."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
self._app_config = app_config
|
||||
self._logger = logger
|
||||
self._ram_cache = ram_cache
|
||||
self._convert_cache = convert_cache
|
||||
self._torch_dtype = torch_dtype(choose_torch_device(), app_config)
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Return a model given its configuration.
|
||||
|
||||
Given a model's configuration as returned by the ModelRecordConfigStore service,
|
||||
return a LoadedModel object that can be used for inference.
|
||||
|
||||
:param model config: Configuration record for this model
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
if model_config.type == "main" and not submodel_type:
|
||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||
|
||||
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
||||
|
||||
if not model_path.exists():
|
||||
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
|
||||
|
||||
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
|
||||
locker = self._load_if_needed(model_config, model_path, submodel_type)
|
||||
return LoadedModel(config=model_config, _locker=locker)
|
||||
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
model_base = self._app_config.models_path
|
||||
result = (model_base / config.path).resolve(), config, submodel_type
|
||||
return result
|
||||
|
||||
def _convert_if_needed(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> Path:
|
||||
cache_path: Path = self._convert_cache.cache_path(config.key)
|
||||
|
||||
if not self._needs_conversion(config, model_path, cache_path):
|
||||
return cache_path if cache_path.exists() else model_path
|
||||
|
||||
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
return self._convert_model(config, model_path, cache_path)
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
def _load_if_needed(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> ModelLockerBase:
|
||||
# TO DO: This is not thread safe!
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
model_variant = getattr(config, "repo_variant", None)
|
||||
self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
|
||||
# This is where the model is actually loaded!
|
||||
with skip_torch_weight_init():
|
||||
loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
config.key,
|
||||
submodel_type=submodel_type,
|
||||
model=loaded_model,
|
||||
size=calc_model_size_by_data(loaded_model),
|
||||
)
|
||||
|
||||
return self._ram_cache.get(
|
||||
key=config.key,
|
||||
submodel_type=submodel_type,
|
||||
stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
|
||||
)
|
||||
|
||||
def get_size_fs(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> int:
|
||||
"""Get the size of the model on disk."""
|
||||
return calc_model_size_by_fs(
|
||||
model_path=model_path,
|
||||
subfolder=submodel_type.value if submodel_type else None,
|
||||
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
||||
)
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
raise NotImplementedError
|
||||
|
||||
# This needs to be implemented in the subclass
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
raise NotImplementedError
|
||||
@@ -3,8 +3,9 @@ from typing import Optional
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
|
||||
from ..util.libc_util import LibcUtil, Struct_mallinfo2
|
||||
|
||||
GB = 2**30 # 1 GB
|
||||
|
||||
@@ -27,7 +28,7 @@ class MemorySnapshot:
|
||||
self.malloc_info = malloc_info
|
||||
|
||||
@classmethod
|
||||
def capture(cls, run_garbage_collector: bool = True):
|
||||
def capture(cls, run_garbage_collector: bool = True) -> Self:
|
||||
"""Capture and return a MemorySnapshot.
|
||||
|
||||
Note: This function has significant overhead, particularly if `run_garbage_collector == True`.
|
||||
@@ -67,7 +68,7 @@ class MemorySnapshot:
|
||||
def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:
|
||||
"""Get a pretty string describing the difference between two `MemorySnapshot`s."""
|
||||
|
||||
def get_msg_line(prefix: str, val1: int, val2: int):
|
||||
def get_msg_line(prefix: str, val1: int, val2: int) -> str:
|
||||
diff = val2 - val1
|
||||
return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n"
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Init file for ModelCache."""
|
||||
|
||||
from .model_cache_base import ModelCacheBase, CacheStats # noqa F401
|
||||
from .model_cache_default import ModelCache # noqa F401
|
||||
|
||||
_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]
|
||||
@@ -0,0 +1,188 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from typing import Dict, Generic, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModel, SubModelType
|
||||
|
||||
|
||||
class ModelLockerBase(ABC):
|
||||
"""Base class for the model locker used by the loader."""
|
||||
|
||||
@abstractmethod
|
||||
def lock(self) -> AnyModel:
|
||||
"""Lock the contained model and move it into VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unlock(self) -> None:
|
||||
"""Unlock the contained model, and remove it from VRAM."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model."""
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheRecord(Generic[T]):
|
||||
"""Elements of the cache."""
|
||||
|
||||
key: str
|
||||
model: T
|
||||
size: int
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
"""Lock this record."""
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Unlock this record."""
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self) -> bool:
|
||||
"""Return true if record is locked."""
|
||||
return self._locks > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
"""Collect statistics on cache performance."""
|
||||
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelCacheBase(ABC, Generic[T]):
|
||||
"""Virtual base class for RAM model cache."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def storage_device(self) -> torch.device:
|
||||
"""Return the storage device (e.g. "CPU" for RAM)."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload from VRAM any models not actively in use."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], device: torch.device) -> None:
|
||||
"""Move model into the indicated device."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stats(self) -> CacheStats:
|
||||
"""Return collected CacheStats object."""
|
||||
pass
|
||||
|
||||
@stats.setter
|
||||
@abstractmethod
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collectin cache statistics."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self) -> Logger:
|
||||
"""Return the logger used by the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
model: T,
|
||||
size: int,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> ModelLockerBase:
|
||||
"""
|
||||
Retrieve model using key and optional submodel_type.
|
||||
|
||||
:param key: Opaque model key
|
||||
:param submodel_type: Type of the submodel to fetch
|
||||
:param stats_name: A human-readable id for the model for the purposes of
|
||||
stats reporting.
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log debugging information on CUDA usage."""
|
||||
pass
|
||||
@@ -0,0 +1,407 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
|
||||
The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_cache_size=7.5)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from logging import Logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
|
||||
from .model_locker import ModelLocker
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Implementation of ModelCacheBase."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._precision: torch.dtype = precision
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
|
||||
# used for stats collection
|
||||
self._stats: Optional[CacheStats] = None
|
||||
|
||||
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
"""Return the logger used by the cache."""
|
||||
return self._logger
|
||||
|
||||
@property
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
return self._lazy_offloading
|
||||
|
||||
@property
|
||||
def storage_device(self) -> torch.device:
|
||||
"""Return the storage device (e.g. "CPU" for RAM)."""
|
||||
return self._storage_device
|
||||
|
||||
@property
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
return self._execution_device
|
||||
|
||||
@property
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the cap on cache size."""
|
||||
return self._max_cache_size
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
return self._stats
|
||||
|
||||
@stats.setter
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collectin cache statistics."""
|
||||
self._stats = stats
|
||||
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
total = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
total += cache_record.size
|
||||
return total
|
||||
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
return key in self._cached_models
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
model: AnyModel,
|
||||
size: int,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
assert key not in self._cached_models
|
||||
|
||||
cache_record = CacheRecord(key, model, size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> ModelLockerBase:
|
||||
"""
|
||||
Retrieve model using key and optional submodel_type.
|
||||
|
||||
:param key: Opaque model key
|
||||
:param submodel_type: Type of the submodel to fetch
|
||||
:param stats_name: A human-readable id for the model for the purposes of
|
||||
stats reporting.
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
)
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
return None
|
||||
|
||||
def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
|
||||
if submodel_type:
|
||||
return f"{model_key}:{submodel_type.value}"
|
||||
else:
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Move any unused models from VRAM."""
|
||||
reserved = self._max_vram_cache_size * GIG
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
if not cache_entry.locked:
|
||||
self.move_model_to_device(cache_entry, self.storage_device)
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device."""
|
||||
# These attributes are not in the base ModelMixin class but in various derived classes.
|
||||
# Some models don't have these attributes, in which case they run in RAM/CPU.
|
||||
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
|
||||
return
|
||||
|
||||
source_device = cache_entry.model.device
|
||||
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
cache_entry.model.to(target_device)
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if (
|
||||
snapshot_before is not None
|
||||
and snapshot_after is not None
|
||||
and snapshot_before.vram is not None
|
||||
and snapshot_after.vram is not None
|
||||
):
|
||||
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# If the estimated model size does not match the change in VRAM, log a warning.
|
||||
if not math.isclose(
|
||||
vram_change,
|
||||
cache_entry.size,
|
||||
rel_tol=0.1,
|
||||
abs_tol=10 * MB,
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % (self.cache_size() / GIG)
|
||||
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
locked_in_vram_models = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
if hasattr(cache_record.model, "device"):
|
||||
if cache_record.model.device == self.storage_device:
|
||||
in_ram_models += 1
|
||||
else:
|
||||
in_vram_models += 1
|
||||
if cache_record.locked:
|
||||
locked_in_vram_models += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
|
||||
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
|
||||
)
|
||||
|
||||
def make_room(self, model_size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
# calculate how much memory this model will require
|
||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||
bytes_needed = model_size
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
current_size = self.cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GIG):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
|
||||
pos = 0
|
||||
models_cleared = 0
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
|
||||
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
|
||||
# https://docs.python.org/3/library/gc.html#gc.get_referrers
|
||||
|
||||
# manualy clear local variable references of just finished function calls
|
||||
# for some reason python don't want to collect it even by gc.collect() immidiately
|
||||
if refs > 2:
|
||||
while True:
|
||||
cleared = False
|
||||
for referrer in gc.get_referrers(cache_entry.model):
|
||||
if type(referrer).__name__ == "frame":
|
||||
# RuntimeError: cannot clear an executing frame
|
||||
with suppress(RuntimeError):
|
||||
referrer.clear()
|
||||
cleared = True
|
||||
# break
|
||||
|
||||
# repeat if referrers changes(due to frame clear), else exit loop
|
||||
if cleared:
|
||||
gc.collect()
|
||||
else:
|
||||
break
|
||||
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
|
||||
f" refs: {refs}"
|
||||
)
|
||||
|
||||
# Expected refs:
|
||||
# 1 from cache_entry
|
||||
# 1 from getrefcount function
|
||||
# 1 from onnx runtime object
|
||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
del self._cache_stack[pos]
|
||||
del self._cached_models[model_key]
|
||||
del cache_entry
|
||||
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
if models_cleared > 0:
|
||||
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
|
||||
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
|
||||
# is high even if no garbage gets collected.)
|
||||
#
|
||||
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
|
||||
# - If models had to be cleared, it's a signal that we are close to our memory limit.
|
||||
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
|
||||
# collected.
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
gc.collect()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Base class and implementation of a class that moves models in and out of VRAM.
|
||||
"""
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
|
||||
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
|
||||
|
||||
|
||||
class ModelLocker(ModelLockerBase):
|
||||
"""Internal class that mediates movement in and out of GPU."""
|
||||
|
||||
def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
|
||||
"""
|
||||
Initialize the model locker.
|
||||
|
||||
:param cache: The ModelCache object
|
||||
:param cache_entry: The entry in the model cache
|
||||
"""
|
||||
self._cache = cache
|
||||
self._cache_entry = cache_entry
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without moving it around."""
|
||||
return self._cache_entry.model
|
||||
|
||||
def lock(self) -> AnyModel:
|
||||
"""Move the model into the execution device (GPU) and lock it."""
|
||||
if not hasattr(self.model, "to"):
|
||||
return self.model
|
||||
|
||||
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
|
||||
self._cache_entry.lock()
|
||||
|
||||
try:
|
||||
if self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
|
||||
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
|
||||
self._cache_entry.loaded = True
|
||||
|
||||
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
|
||||
self._cache.print_cuda_stats()
|
||||
|
||||
except Exception:
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
return self.model
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Call upon exit from context."""
|
||||
if not hasattr(self.model, "to"):
|
||||
return
|
||||
|
||||
self._cache_entry.unlock()
|
||||
if not self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
self._cache.print_cuda_stats()
|
||||
122
invokeai/backend/model_manager/load/model_loader_registry.py
Normal file
122
invokeai/backend/model_manager/load/model_loader_registry.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
"""
|
||||
This module implements a system in which model loaders register the
|
||||
type, base and format of models that they know how to load.
|
||||
|
||||
Use like this:
|
||||
|
||||
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model = cls(
|
||||
app_config=app_config,
|
||||
logger=logger,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
"""
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional, Tuple, Type
|
||||
|
||||
from ..config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
VaeCheckpointConfig,
|
||||
VaeDiffusersConfig,
|
||||
)
|
||||
from . import ModelLoaderBase
|
||||
|
||||
|
||||
class ModelLoaderRegistryBase(ABC):
|
||||
"""This class allows model loaders to register their type, base and format."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_implementation(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""
|
||||
Get subclass of ModelLoaderBase registered to handle base and type.
|
||||
|
||||
Parameters:
|
||||
:param config: Model configuration record, as returned by ModelRecordService
|
||||
:param submodel_type: Submodel to fetch (main models only)
|
||||
:return: tuple(loader_class, model_config, submodel_type)
|
||||
|
||||
Note that the returned model config may be different from one what passed
|
||||
in, in the event that a submodel type is provided.
|
||||
"""
|
||||
|
||||
|
||||
class ModelLoaderRegistry:
|
||||
"""
|
||||
This class allows model loaders to register their type, base and format.
|
||||
"""
|
||||
|
||||
_registry: Dict[str, Type[ModelLoaderBase]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
||||
key = cls._to_registry_key(base, type, format)
|
||||
if key in cls._registry:
|
||||
raise Exception(
|
||||
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
|
||||
)
|
||||
cls._registry[key] = subclass
|
||||
return subclass
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def get_implementation(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||
|
||||
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
|
||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||
if not implementation:
|
||||
raise NotImplementedError(
|
||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||
)
|
||||
return implementation, conf2, submodel_type
|
||||
|
||||
@classmethod
|
||||
def _handle_subtype_overrides(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||
model_path = Path(config.vae)
|
||||
config_class = (
|
||||
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
|
||||
)
|
||||
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
|
||||
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
|
||||
submodel_type = None
|
||||
else:
|
||||
new_conf = config
|
||||
return new_conf, submodel_type
|
||||
|
||||
@staticmethod
|
||||
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
|
||||
return "-".join([base.value, type.value, format.value])
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Init file for model_loaders.
|
||||
"""
|
||||
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for ControlNet model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||
class ControlnetLoader(GenericDiffusersLoader):
|
||||
"""Class to load ControlNet models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
assert hasattr(config, "config")
|
||||
config_file = config.config
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
convert_controlnet_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
original_config_file=self._app_config.root_path / config_file,
|
||||
image_size=512,
|
||||
scan_needed=True,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
)
|
||||
return output_path
|
||||
@@ -0,0 +1,90 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for simple diffusers model loading in InvokeAI."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
||||
class GenericDiffusersLoader(ModelLoader):
|
||||
"""Class to load simple diffusers models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
model_class = self.get_hf_load_class(model_path)
|
||||
if submodel_type is not None:
|
||||
raise Exception(f"There are no submodels in models of type {model_class}")
|
||||
variant = model_variant.value if model_variant else None
|
||||
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
|
||||
return result
|
||||
|
||||
# TO DO: Add exception handling
|
||||
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
|
||||
if submodel_type:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||
module, class_name = config[submodel_type.value]
|
||||
result = self._hf_definition_to_type(module=module, class_name=class_name)
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException(
|
||||
f'The "{submodel_type}" submodel is not available for this model.'
|
||||
) from e
|
||||
else:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||
class_name = config.get("_class_name", None)
|
||||
if class_name:
|
||||
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||
if config.get("model_type", None) == "clip_vision_model":
|
||||
class_name = config.get("architectures")
|
||||
assert class_name is not None
|
||||
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
|
||||
if not class_name:
|
||||
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
return result
|
||||
|
||||
# TO DO: Add exception handling
|
||||
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
|
||||
if module in ["diffusers", "transformers"]:
|
||||
res_type = sys.modules[module]
|
||||
else:
|
||||
res_type = sys.modules["diffusers"].pipelines
|
||||
result: ModelMixin = getattr(res_type, class_name)
|
||||
return result
|
||||
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||
|
||||
|
||||
class ConfigLoader(ConfigMixin):
|
||||
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Load a diffusrs ConfigMixin configuration."""
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
# Diffusers doesn't provide typing info
|
||||
return super().load_config(*args, **kwargs) # type: ignore
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for IP Adapter model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||
class IPAdapterInvokeAILoader(ModelLoader):
|
||||
"""Class to load IP Adapter diffusers models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
return model
|
||||
78
invokeai/backend/model_manager/load/model_loaders/lora.py
Normal file
78
invokeai/backend/model_manager/load/model_loaders/lora.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for LoRA model loading in InvokeAI."""
|
||||
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
||||
class LoraLoader(ModelLoader):
|
||||
"""Class to load LoRA models."""
|
||||
|
||||
# We cheat a little bit to get access to the model base
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
super().__init__(app_config, logger, ram_cache, convert_cache)
|
||||
self._model_base: Optional[BaseModelType] = None
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
assert self._model_base is not None
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
base_model=self._model_base,
|
||||
)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
self._model_base = (
|
||||
config.base
|
||||
) # cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
|
||||
model_base_path = self._app_config.models_path
|
||||
model_path = model_base_path / config.path
|
||||
|
||||
if config.format == ModelFormat.Diffusers:
|
||||
for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder
|
||||
path = model_base_path / config.path / f"pytorch_lora_weights.{ext}"
|
||||
if path.exists():
|
||||
model_path = path
|
||||
break
|
||||
|
||||
result = model_path.resolve(), config, submodel_type
|
||||
return result
|
||||
42
invokeai/backend/model_manager/load/model_loaders/onnx.py
Normal file
42
invokeai/backend/model_manager/load/model_loaders/onnx.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for Onnx model loading in InvokeAI."""
|
||||
|
||||
# This should work the same as Stable Diffusion pipelines
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
||||
class OnnyxDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load onnx models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading onnx pipelines.")
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
variant=variant,
|
||||
) # type: ignore
|
||||
return result
|
||||
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for StableDiffusion model loading in InvokeAI."""
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
model_base_to_model_type = {
|
||||
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusionXL: "SDXL",
|
||||
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
|
||||
}
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading main pipelines.")
|
||||
load_class = self._get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
variant=variant,
|
||||
) # type: ignore
|
||||
return result
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
variant = config.variant
|
||||
base = config.base
|
||||
pipeline_class = (
|
||||
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
|
||||
)
|
||||
|
||||
config_file = config.config
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
convert_ckpt_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
model_type=self.model_base_to_model_type[base],
|
||||
model_version=base,
|
||||
model_variant=variant,
|
||||
original_config_file=self._app_config.root_path / config_file,
|
||||
extract_ema=True,
|
||||
scan_needed=True,
|
||||
pipeline_class=pipeline_class,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
precision=self._torch_dtype,
|
||||
load_safety_checker=False,
|
||||
)
|
||||
return output_path
|
||||
@@ -0,0 +1,57 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for TI model loading in InvokeAI."""
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder
|
||||
)
|
||||
class TextualInversionLoader(ModelLoader):
|
||||
"""Class to load TI models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in a TI model.")
|
||||
model = TextualInversionModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
model_path = self._app_config.models_path / config.path
|
||||
|
||||
if config.format == ModelFormat.EmbeddingFolder:
|
||||
path = model_path / "learned_embeds.bin"
|
||||
else:
|
||||
path = model_path
|
||||
|
||||
if not path.exists():
|
||||
raise OSError(f"The embedding file at {path} was not found")
|
||||
|
||||
return path, config, submodel_type
|
||||
68
invokeai/backend/model_manager/load/model_loaders/vae.py
Normal file
68
invokeai/backend/model_manager/load/model_loaders/vae.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for VAE model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
class VaeLoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
# TO DO: check whether sdxl VAE models convert.
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
config_file = (
|
||||
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
|
||||
)
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
|
||||
assert isinstance(ckpt_config, DictConfig)
|
||||
|
||||
vae_model = convert_ldm_vae_to_diffusers(
|
||||
checkpoint=checkpoint,
|
||||
vae_config=ckpt_config,
|
||||
image_size=512,
|
||||
)
|
||||
vae_model.to(self._torch_dtype) # set precision appropriately
|
||||
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||
return output_path
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user