Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-20 04:17:59 -05:00)

Compare commits: 52 commits, `separate-g` ... `ryan/dense`

Commits in this comparison (SHA1):

ff950bc5cd, 969982b789, b8cbff828b, d3a40c5b2b, 57266d36a2, 41e1a9f202,
bcfb43e5f0, a665f20fb5, d313e5eb70, 271f8f2414, 5fad379192, ad18429fe3,
942efa011e, ffc4ebb14c, 5b3adf0740, a5c94fba43, 3e14bd6c45, 8721926f14,
d87ff3a206, 7d9671014b, 4a1acd4db9, 8989a6cdc6, f44d3da9b1, 1bbd4f751d,
bdf3691ad0, e7f7ae660d, e132afb705, 5f49e7ae26, 53ebca58ff, ee1b3157ce,
e7ec13f209, cad3e5dbd7, 845c4e93ae, 54971afe44, cfba51aed5, 2966c8de2c,
b0fcbe552e, d132fb4818, 2d5d370f38, 878bbc3527, caa690e24d, 38248b988f,
ba4788007f, ef51005881, 7b0326d7f7, f590b39f88, 58277c6ada, 382fa57f3b,
ee3abc171d, bf72cee555, e866e3b19f, 16e574825c
@@ -32,6 +32,7 @@ model. These are the:

Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.

## Location of the Code

The four main services can be found in
@@ -62,21 +63,23 @@ provides the following fields:

|----------------|-----------------|------------------|
| `key`           | str            | Unique identifier for the model |
| `name`          | str            | Name of the model (not unique) |
| `model_type`    | ModelType      | The type of the model |
| `model_format`  | ModelFormat    | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
| `base_model`    | BaseModelType  | The base model that the model is compatible with |
| `path`          | str            | Location of model on disk |
- | `hash`          | str            | Hash of the model |
+ | `original_hash` | str            | Hash of the model when it was first installed |
+ | `current_hash`  | str            | Most recent hash of the model's contents |
| `description`   | str            | Human-readable description of the model (optional) |
| `source`        | str            | Model's source URL or repo id (optional) |

The `key` is a unique 32-character random ID which was generated at
- install time. The `hash` field stores a hash of the model's
+ install time. The `original_hash` field stores a hash of the model's
contents at install time obtained by sampling several parts of the
model's files using the `imohash` library. Over the course of the
model's lifetime it may be transformed in various ways, such as
changing its precision or converting it from a .safetensors to a
- diffusers model.
+ diffusers model. When this happens, `original_hash` is unchanged, but
+ `current_hash` is updated to indicate the current contents.

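As an illustrative sketch, reading these fields from a record retrieved through the record store described later in this document might look like the following; `store` and `key` are assumed to already exist:

```
# Illustrative sketch: `store` is an open model record store and `key` is
# the 32-character key of an installed model.
model_conf = store.get_model(key)
print(model_conf.name)             # human-readable name (not unique)
print(model_conf.base_model)       # base model the model is compatible with
print(model_conf.path)             # on-disk location, possibly relative to models_dir
print(model_conf.original_hash)    # imohash sampled at install time; never changes
print(model_conf.current_hash)     # refreshed when the model's contents are transformed
```
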
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also
@@ -91,6 +94,7 @@ The `path` field can be absolute or relative. If relative, it is taken

to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file.


### CheckpointConfig

This adds support for checkpoint configurations, and adds the
@@ -170,7 +174,7 @@ store = context.services.model_manager.store

or from elsewhere in the code by accessing
`ApiDependencies.invoker.services.model_manager.store`.

### Creating a `ModelRecordService`

To create a new `ModelRecordService` database or open an existing one,
you can directly create either a `ModelRecordServiceSQL` or a
@@ -213,27 +217,27 @@ for use in the InvokeAI web server. Its signature is:

```
def open(
    cls,
    config: InvokeAIAppConfig,
    conn: Optional[sqlite3.Connection] = None,
    lock: Optional[threading.Lock] = None
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
```

The way it works is as follows:

1. Retrieve the value of the `model_config_db` option from the user's
   `invokeai.yaml` config file.
2. If `model_config_db` is `auto` (the default), then:
   - Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
     opened on the passed connection and lock.
   - Open up a new connection to `databases/invokeai.db` if `conn`
     and/or `lock` are missing (see note below).
3. If `model_config_db` is a Path, then use `from_db_file`
   to return the appropriate type of ModelRecordService.
4. If `model_config_db` is None, then retrieve the legacy
   `conf_path` option from `invokeai.yaml` and use the Path
   indicated there. This will default to `configs/models.yaml`.

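A rough sketch of that selection logic, for illustration only (the helper and constructor calls used here are assumptions, not the actual implementation of `open()`):

```
# Illustrative sketch of the dispatch described above.
def open_record_store(config, conn=None, lock=None):
    setting = config.model_config_db                        # from invokeai.yaml
    if setting == "auto":
        if conn is None or lock is None:
            conn, lock = open_default_db(config)            # hypothetical helper: databases/invokeai.db
        return ModelRecordServiceSQL(conn, lock)            # assumed constructor arguments
    if isinstance(setting, Path):
        return ModelRecordServiceBase.from_db_file(setting)
    # setting is None: fall back to the legacy conf_path (configs/models.yaml by default)
    return ModelRecordServiceFile(legacy_conf_path(config)) # hypothetical helper
```
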
So a typical startup pattern would be:

```
@@ -251,7 +255,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)

Configurations can be retrieved in several ways.

#### get_model(key) -> AnyModelConfig:

The basic functionality is to call the record store object's
`get_model()` method with the desired model's unique key. It returns
@@ -268,28 +272,28 @@ print(model_conf.path)

If the key is unrecognized, this call raises an
`UnknownModelException`.

#### exists(key) -> AnyModelConfig:

Returns True if a model with the given key exists in the database.

#### search_by_path(path) -> AnyModelConfig:

Returns the configuration of the model whose path is `path`. The path
is matched using a simple string comparison and won't correctly match
models referred to by different paths (e.g. using symbolic links).

#### search_by_name(name, base, type) -> List[AnyModelConfig]:

This method searches for models that match some combination of `name`,
`BaseType` and `ModelType`. Calling without any arguments will return
all the models in the database.

#### all_models() -> List[AnyModelConfig]:

Return all the model configs in the database. Exactly equivalent to
calling `search_by_name()` with no arguments.

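For illustration, a minimal sketch of these lookup calls (assuming `store` is an open record store, and using the enum string values shown elsewhere in this document):

```
# Illustrative sketch only.
if store.exists(key):
    config = store.get_model(key)

# Look a model up by its on-disk path (plain string comparison).
config = store.search_by_path('/opt/models/sushi.safetensors')

# All "main" models for the sd-1 base named "sushi"; any argument may be omitted.
matches = store.search_by_name(name='sushi', base=BaseModelType('sd-1'), type=ModelType('main'))

# Every model in the database.
everything = store.all_models()
```
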
#### search_by_tag(tags) -> List[AnyModelConfig]:

`tags` is a list of strings. This method returns a list of model
configs that contain all of the given tags. Examples:
@@ -308,11 +312,11 @@ commercializable_models = [x for x in store.all_models() \

                           if x.license.contains('allowCommercialUse=Sell')]
```

#### version() -> str:

Returns the version of the database, currently at `3.2`.

#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase:

This method exists to ease the transition from the previous version of
the model manager, in which `get_model()` took the three arguments
@@ -333,7 +337,7 @@ model and pass its key to `get_model()`.

Several methods allow you to create and update stored model config
records.

#### add_model(key, config) -> AnyModelConfig:

Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of
@@ -348,7 +352,7 @@ model with the same key is already in the database, or an

`InvalidModelConfigException` if a dict was passed and Pydantic
experienced a parse or validation error.

### update_model(key, config) -> AnyModelConfig:

Given a key and a configuration, this will update the model
configuration record in the database. `config` can be either a
@@ -366,31 +370,31 @@ The `ModelInstallService` class implements the

shop for all your model install needs. It provides the following
functionality:

- Registering a model config record for a model already located on the
  local filesystem, without moving it or changing its path.

- Installing a model already located on the local filesystem, by
  moving it into the InvokeAI root directory under the
  `models` folder (or wherever config parameter `models_dir`
  specifies).

- Probing of models to determine their type, base type and other key
  information.

- Interface with the InvokeAI event bus to provide status updates on
  the download, installation and registration process.

- Downloading a model from an arbitrary URL and installing it in
  `models_dir`.

- Special handling for Civitai model URLs which allow the user to
  paste in a model page's URL or download link.

- Special handling for HuggingFace repo_ids to recursively download
  the contents of the repository, paying attention to alternative
  variants such as fp16.

- Saving tags and other metadata about the model into the invokeai database
  when fetching from a repo that provides that type of information
  (currently only Civitai and HuggingFace).

@@ -423,8 +427,8 @@ queue.start()

installer = ModelInstallService(app_config=config,
                                record_store=record_store,
                                download_queue=queue
                                )
installer.start()
```

@@ -439,6 +443,7 @@ required parameters:

| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
| `session`        | Optional[requests.Session]   | Swap in a different Session object (usually for debugging) |


Once initialized, the installer will provide the following methods:

#### install_job = installer.heuristic_import(source, [config], [access_token])
@@ -452,15 +457,15 @@ The `source` is a string that can be any of these forms

1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
2. A URL pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
3. A HuggingFace repo_id with any of the following formats:
   - `model/name` -- entire model
   - `model/name:fp32` -- entire model, using the fp32 variant
   - `model/name:fp16:vae` -- vae submodel, using the fp16 variant
   - `model/name::vae` -- vae submodel, using default precision
   - `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
   - `model/name::path/to/model.safetensors` -- an individual model file, default variant

Note that by specifying a relative path to the top of the HuggingFace
repo, you can download and install arbitrary model files.

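For illustration, a hedged sketch of passing a few of these source forms to `heuristic_import()` (the repo_id and token values are placeholders):

```
# Illustrative sketch only: `installer` is a started ModelInstallService.
job1 = installer.heuristic_import('C:\\users\\fred\\model.safetensors')      # local file
job2 = installer.heuristic_import('https://civitai.com/models/58390/detail-tweaker-lora-lora')
job3 = installer.heuristic_import('stabilityai/sdxl-turbo::vae',             # HF repo_id, vae submodel
                                  access_token='<optional-hf-token>')
```
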
The variant, if not provided, will be automatically filled in with
`fp32` if the user has requested full precision, and `fp16`
@@ -486,9 +491,9 @@ following illustrates basic usage:

```
from invokeai.app.services.model_install import (
    LocalModelSource,
    HFModelSource,
    URLModelSource,
)

source1 = LocalModelSource(path='/opt/models/sushi.safetensors')   # a local safetensors file
@@ -508,13 +513,13 @@ for source in [source1, source2, source3, source4, source5, source6, source7]:

source2job = installer.wait_for_installs(timeout=120)
for source in sources:
    job = source2job[source]
    if job.complete:
        model_config = job.config_out
        model_key = model_config.key
        print(f"{source} installed as {model_key}")
    elif job.errored:
        print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")

```

As shown here, the `import_model()` method accepts a variety of
@@ -523,7 +528,7 @@ HuggingFace repo_ids with and without a subfolder designation,

Civitai model URLs and arbitrary URLs that point to checkpoint files
(but not to folders).

Each call to `import_model()` returns a `ModelInstallJob`,
an object which tracks the progress of the install.

If a remote model is requested, the model's files are downloaded in
@@ -550,7 +555,7 @@ The full list of arguments to `import_model()` is as follows:

| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |

The next few sections describe the various types of ModelSource that
can be passed to `import_model()`.

`config` can be used to override all or a portion of the configuration
attributes returned by the model prober. See the section below for
@@ -561,6 +566,7 @@ details.

This is used for a model that is located on a locally-accessible Posix
filesystem, such as a local disk or networked fileshare.


| **Argument**     | **Type**                     | **Default** | **Description**                           |
|------------------|------------------------------|-------------|-------------------------------------------|
| `path`           | str | Path                   | None        | Path to the model file or directory       |
@@ -619,6 +625,7 @@ HuggingFace has the most complicated `ModelSource` structure:

| `subfolder`      | Path                         | None        | Look for the model in a subfolder of the repo. |
| `access_token`   | str                          | None        | An access token needed to gain access to a subscriber's-only model. |


The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.

The `variant` is one of the various diffusers formats that HuggingFace
@@ -654,6 +661,7 @@ in. To download these files, you must provide an

`HfFolder.get_token()` will be called to fill it in with the cached
one.


#### Monitoring the install job process

When you create an install job with `import_model()`, it launches the
@@ -667,13 +675,14 @@ The `ModelInstallJob` class has the following structure:

| `id`           | `int`           | Integer ID for this job |
| `status`       | `InstallStatus` | An enum of [`waiting`, `downloading`, `running`, `completed`, `error` and `cancelled`] |
| `config_in`    | `dict`          | Overriding configuration values provided by the caller |
| `config_out`   | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
| `inplace`      | `boolean`       | True if the caller asked to install the model in place using its local path |
| `source`       | `ModelSource`   | The local path, remote URL or repo_id of the model to be installed |
| `local_path`   | `Path`          | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
| `error_type`   | `str`           | Name of the exception that led to an error status |
| `error`        | `str`           | Traceback of the error |


If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus
as an event of type `EventServiceBase.model_event`, a timestamp and
@@ -693,13 +702,14 @@ following keys:

| `total_bytes` | int       | Total size of all the files that make up the model |
| `parts`       | List[Dict]| Information on the progress of the individual files that make up the model |


`parts` is a list of dictionaries giving information on each of
the component pieces of the download. The dictionary's keys are
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
the like-named keys in the main event.

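As an illustrative sketch, a listener can sum these per-part counters to compute overall progress (the payload structure is as described above):

```
# Illustrative sketch only: `payload` is the downloading event payload described above.
def fraction_complete(payload: dict) -> float:
    downloaded = sum(part['bytes'] for part in payload['parts'])
    total = sum(part['total_bytes'] for part in payload['parts'])
    return downloaded / total if total else 0.0
```
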
Note that downloading events will not be issued for local models, and
that downloading events occur *before* the running event.

##### `model_install_running`

@@ -742,13 +752,14 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`

and `cancelled`, as well as `in_terminal_state`. The last will return
True if the job is in the complete, errored or cancelled states.

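A short illustrative polling loop built on these properties (assuming `job` is a `ModelInstallJob` returned by `import_model()`):

```
# Illustrative sketch only: wait for an install job to reach a terminal state.
import time

while not job.in_terminal_state:
    time.sleep(1)

if job.complete:
    print(f"installed as {job.config_out.key}")
elif job.errored:
    print(f"{job.error_type}:\n{job.error}")
else:
    print("install was cancelled")
```
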
#### Model configuration and probing

The install service uses the `invokeai.backend.model_manager.probe`
module during import to determine the model's type, base type, and
other configuration parameters. Among other things, it assigns a
default name and description for the model based on probed
fields.

When downloading remote models is implemented, additional
configuration information, such as a list of trigger terms, will be
@@ -763,11 +774,11 @@ attributes. Here is an example of setting the

```
install_job = installer.import_model(
    source=HFModelSource(repo_id='stabilityai/stable-diffusion-2-1', variant='fp32'),
    config=dict(
        prediction_type=SchedulerPredictionType('v_prediction'),
        name='stable diffusion 2 base model',
    )
)
```

### Other installer methods
@@ -851,6 +862,7 @@ This method is similar to `unregister()`, but also unconditionally

deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy.


#### path = installer.download_and_cache(remote_source, [access_token], [timeout])

This utility routine will download the model file located at source,
@@ -941,7 +953,7 @@ following fields:


When you create a job, you can assign it a `priority`. If multiple
jobs are queued, the job with the lowest priority runs first. (Don't
blame me! The Unix developers came up with this convention.)

Every job has a `source` and a `destination`. `source` is a string in
the base class, but subclasses redefine it more specifically.
@@ -962,7 +974,7 @@ is in its lifecycle. Values are defined in the string enum

`DownloadJobStatus`, a symbol available from
`invokeai.app.services.download_manager`. Possible values are:

| **Value**    | **String Value**    | **Description**   |
|--------------|---------------------|-------------------|
| `IDLE`       | idle                | Job created, but not submitted to the queue |
| `ENQUEUED`   | enqueued            | Job is patiently waiting on the queue |
@@ -979,7 +991,7 @@ debugging and performance testing.


In case of an error, the Exception that caused the error will be
placed in the `error` field, and the job's status will be set to
`DownloadJobStatus.ERROR`.

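A minimal illustrative check for that state:

```
# Illustrative sketch only: `job` is a download job obtained from the queue.
if job.status == DownloadJobStatus.ERROR:
    print(f"download of {job.source} failed: {job.error}")
```
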
After an error occurs, any partially downloaded files will be deleted
from disk, unless `preserve_partial_downloads` was set to True at job
@@ -1028,11 +1040,11 @@ While a job is being downloaded, the queue will emit events at

periodic intervals. A typical series of events during a successful
download session will look like this:

- enqueued
- running
- running
- running
- completed

There will be a single enqueued event, followed by one or more running
events, and finally one `completed`, `error` or `cancelled`
@@ -1041,12 +1053,12 @@ events.

It is possible for a caller to pause download temporarily, in which
case the events may look something like this:

- enqueued
- running
- running
- paused
- running
- completed

The download queue logs when downloads start and end (unless `quiet`
is set to True at initialization time) but doesn't log any progress
@@ -1108,11 +1120,11 @@ A typical initialization sequence will look like:

from invokeai.app.services.download_manager import DownloadQueueService

def log_download_event(job: DownloadJobBase):
    logger.info(f'job={job.id}: status={job.status}')

queue = DownloadQueueService(
    event_handlers=[log_download_event]
)
```

Event handlers can be provided to the queue at initialization time as
@@ -1143,9 +1155,9 @@ To use the former method, follow this example:

```
job = DownloadJobRemoteSource(
    source='http://www.civitai.com/models/13456',
    destination='/tmp/models/',
    event_handlers=[my_handler1, my_handler2],   # if desired
)
queue.submit_download_job(job, start=True)
```

@@ -1160,13 +1172,13 @@ To have the queue create the job for you, follow this example instead:

```
job = queue.create_download_job(
    source='http://www.civitai.com/models/13456',
    destdir='/tmp/models/',
    filename='my_model.safetensors',
    event_handlers=[my_handler1, my_handler2],   # if desired
    start=True,
)
```

The `filename` argument forces the downloader to use the specified
name for the file rather than the name provided by the remote source,
and is equivalent to manually specifying a destination of
@@ -1175,6 +1187,7 @@ and is equivalent to manually specifying a destination of

Here is the full list of arguments that can be provided to
`create_download_job()`:


| **Argument**     | **Type**                     | **Default** | **Description**                           |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source`         | Union[str, Path, AnyHttpUrl] |             | Download remote or local source           |
@@ -1187,7 +1200,7 @@ Here is the full list of arguments that can be provided to


Internally, `create_download_job()` has a little bit of logic
that looks at the type of the source and selects the right subclass of
`DownloadJobBase` to create and enqueue.

**TODO**: move this logic into its own method for overriding in
subclasses.
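For illustration, that dispatch might look roughly like the following sketch (the local-path subclass name is an assumption; only `DownloadJobRemoteSource` is named in this document):

```
# Illustrative sketch of the source-based dispatch; not the actual implementation.
def make_download_job(source, **kwargs):
    if str(source).startswith(("http://", "https://")):
        return DownloadJobRemoteSource(source=source, **kwargs)
    return DownloadJobLocalPath(source=Path(source), **kwargs)   # assumed subclass name
```
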
@@ -1262,7 +1275,7 @@ for getting the model to run. For example "author" is metadata, while

"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`.

### Example Usage:

```
from invokeai.backend.model_manager.metadata import (
@@ -1315,6 +1328,7 @@ This is the common base class for metadata:

| `author` | str      | Model's author |
| `tags`   | Set[str] | Model tags |


Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected
@@ -1334,6 +1348,7 @@ This descends from `ModelMetadataBase` and adds the following fields:

| `last_modified`| datetime   | Date of last commit of this model to the repo |
| `files`        | List[Path] | List of the files in the model repo |


#### `CivitaiMetadata`

This descends from `ModelMetadataBase` and adds the following fields:
@@ -1400,6 +1415,7 @@ testing suite to avoid hitting the internet.

The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:


#### HuggingFaceMetadataFetch

This overrides its base class `from_json()` method to return a
@@ -1418,12 +1434,13 @@ retrieves its metadata. Functionally equivalent to `from_id()`, the

only difference is that it returns a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.


### Metadata Storage

The `ModelMetadataStore` provides a simple facility to store model
metadata in the `invokeai.db` database. The data is stored as a JSON
blob, with a few common fields (`name`, `author`, `tags`) broken out
to be searchable.

When a metadata object is saved to the database, it is identified
using the model key, _and this key must correspond to an existing
@@ -1518,16 +1535,16 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist

config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
    max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
    cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
)
loader = ModelLoadService(
    app_config=config,
    ram_cache=ram_cache,
    convert_cache=convert_cache,
    registry=ModelLoaderRegistry
)
```

@@ -1550,6 +1567,7 @@ The returned `LoadedModel` object contains a copy of the configuration

record returned by the model record `get_model()` method, as well as
the in-memory loaded model:


| **Attribute Name** | **Type**        | **Description** |
|--------------------|-----------------|------------------|
| `config`           | AnyModelConfig  | A copy of the model's configuration record for retrieving base type, etc. |
@@ -1563,6 +1581,7 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,

models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.


`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns
@@ -1571,14 +1590,14 @@ the model. Use it like this:

```
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
    image = vae.decode(latents)[0]
```

`get_model_by_key()` may raise any of the following exceptions:

- `UnknownModelException` -- key not in database
- `ModelNotFoundException` -- key in database but model not found at path
- `NotImplementedException` -- the loader doesn't know how to load this type of model

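For example, a defensive caller might guard against the first two (illustrative sketch only):

```
# Illustrative sketch only.
try:
    model_info = loader.get_model_by_key(key, SubModelType('vae'))
except UnknownModelException:
    print(f"{key} is not in the model database")
except ModelNotFoundException:
    print(f"{key} is registered, but its files were not found on disk")
```
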
### Emitting model loading events

@@ -1590,15 +1609,15 @@ following payload:

```
payload=dict(
    queue_id=queue_id,
    queue_item_id=queue_item_id,
    queue_batch_id=queue_batch_id,
    graph_execution_state_id=graph_execution_state_id,
    model_key=model_key,
    submodel_type=submodel,
    hash=model_info.hash,
    location=str(model_info.location),
    precision=str(model_info.precision),
)
```

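An illustrative consumer of that payload (sketch only; how the handler is registered on the event bus is not shown here):

```
# Illustrative sketch only: report which model/submodel was just loaded.
def on_model_load_event(payload: dict) -> None:
    print(f"model {payload['model_key']} ({payload['submodel_type']}) "
          f"loaded at {payload['precision']} from {payload['location']}")
```
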
@@ -1705,7 +1724,6 @@ object, or in `context.services.model_manager` from within an

invocation.

In the examples below, we have retrieved the manager using:

```
mm = ApiDependencies.invoker.services.model_manager
```
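Putting the pieces together, a brief illustrative sketch of using the retrieved manager (only the `store` attribute is documented above; the loop is purely for illustration):

```
# Illustrative sketch only.
store = mm.store                 # the model record store described earlier
for cfg in store.all_models():
    print(cfg.key, cfg.name, cfg.base_model)
```
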
@@ -26,6 +26,7 @@ from ..services.invocation_services import InvocationServices

from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_manager.model_manager_default import ModelManagerService
+ from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
@@ -92,9 +93,10 @@ class ApiDependencies:

            ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
        )
        download_queue_service = DownloadQueueService(event_bus=events)
+       model_metadata_service = ModelMetadataStoreSQL(db=db)
        model_manager = ModelManagerService.build_model_manager(
            app_config=configuration,
-           model_record_service=ModelRecordServiceSQL(db=db),
+           model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
            download_queue=download_queue_service,
            events=events,
        )
@@ -3,7 +3,9 @@

import pathlib
import shutil
- from typing import Any, Dict, List, Optional
+ from hashlib import sha1
+ from random import randbytes
+ from typing import Any, Dict, List, Optional, Set

from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@@ -12,11 +14,15 @@ from starlette.exceptions import HTTPException

from typing_extensions import Annotated

from invokeai.app.services.model_install import ModelInstallJob
+ from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
from invokeai.app.services.model_records import (
+     DuplicateModelException,
    InvalidModelException,
+     ModelRecordOrderBy,
+     ModelSummary,
    UnknownModelException,
)
- from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
+ from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
@@ -25,6 +31,9 @@ from invokeai.backend.model_manager.config import (

    ModelType,
    SubModelType,
)
+ from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
+ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
+ from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
from invokeai.backend.model_manager.search import ModelSearch

from ..dependencies import ApiDependencies
@@ -40,6 +49,15 @@ class ModelsList(BaseModel):

    model_config = ConfigDict(use_enum_values=True)


+ class ModelTagSet(BaseModel):
+     """Return tags for a set of models."""
+
+     key: str
+     name: str
+     author: str
+     tags: Set[str]
+
+
##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.
@@ -50,16 +68,19 @@ example_model_config = {

    "base": "sd-1",
    "type": "main",
    "format": "checkpoint",
-     "config_path": "string",
+     "config": "string",
    "key": "string",
-     "hash": "string",
+     "original_hash": "string",
+     "current_hash": "string",
    "description": "string",
    "source": "string",
-     "converted_at": 0,
+     "last_modified": 0,
+     "vae": "string",
    "variant": "normal",
    "prediction_type": "epsilon",
    "repo_variant": "fp16",
    "upcast_attention": False,
+     "ztsnr_training": False,
}

example_model_input = {
@@ -68,12 +89,50 @@ example_model_input = {

    "base": "sd-1",
    "type": "main",
    "format": "checkpoint",
-     "config_path": "configs/stable-diffusion/v1-inference.yaml",
+     "config": "configs/stable-diffusion/v1-inference.yaml",
    "description": "Model description",
    "vae": None,
    "variant": "normal",
}

+ example_model_metadata = {
+     "name": "ip_adapter_sd_image_encoder",
+     "author": "InvokeAI",
+     "tags": [
+         "transformers",
+         "safetensors",
+         "clip_vision_model",
+         "endpoints_compatible",
+         "region:us",
+         "has_space",
+         "license:apache-2.0",
+     ],
+     "files": [
+         {
+             "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
+             "path": "ip_adapter_sd_image_encoder/README.md",
+             "size": 628,
+             "sha256": None,
+         },
+         {
+             "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
+             "path": "ip_adapter_sd_image_encoder/config.json",
+             "size": 560,
+             "sha256": None,
+         },
+         {
+             "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
+             "path": "ip_adapter_sd_image_encoder/model.safetensors",
+             "size": 2528373448,
+             "sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
+         },
+     ],
+     "type": "huggingface",
+     "id": "InvokeAI/ip_adapter_sd_image_encoder",
+     "tag_dict": {"license": "apache-2.0"},
+     "last_modified": "2023-09-23T17:33:25Z",
+ }

##############################################################################
# ROUTES
##############################################################################
@@ -153,16 +212,89 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))


-# @model_manager_router.get("/summary", operation_id="list_model_summary")
-# async def list_model_summary(
-# page: int = Query(default=0, description="The page to get"),
-# per_page: int = Query(default=10, description="The number of models per page"),
-# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
-# ) -> PaginatedResults[ModelSummary]:
-# """Gets a page of model summary data."""
-# record_store = ApiDependencies.invoker.services.model_manager.store
-# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
-# return results
+@model_manager_router.get("/summary", operation_id="list_model_summary")
+async def list_model_summary(
+page: int = Query(default=0, description="The page to get"),
+per_page: int = Query(default=10, description="The number of models per page"),
+order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
+) -> PaginatedResults[ModelSummary]:
+"""Gets a page of model summary data."""
+record_store = ApiDependencies.invoker.services.model_manager.store
+results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
+return results


+@model_manager_router.get(
+"/i/{key}/metadata",
+operation_id="get_model_metadata",
+responses={
+200: {
+"description": "The model metadata was retrieved successfully",
+"content": {"application/json": {"example": example_model_metadata}},
+},
+400: {"description": "Bad request"},
+},
+)
+async def get_model_metadata(
+key: str = Path(description="Key of the model repo metadata to fetch."),
+) -> Optional[AnyModelRepoMetadata]:
+"""Get a model metadata object."""
+record_store = ApiDependencies.invoker.services.model_manager.store
+result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
+
+return result


+@model_manager_router.patch(
+"/i/{key}/metadata",
+operation_id="update_model_metadata",
+responses={
+201: {
+"description": "The model metadata was updated successfully",
+"content": {"application/json": {"example": example_model_metadata}},
+},
+400: {"description": "Bad request"},
+},
+)
+async def update_model_metadata(
+key: str = Path(description="Key of the model repo metadata to fetch."),
+changes: ModelMetadataChanges = Body(description="The changes"),
+) -> Optional[AnyModelRepoMetadata]:
+"""Updates or creates a model metadata object."""
+record_store = ApiDependencies.invoker.services.model_manager.store
+metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
+
+try:
+original_metadata = record_store.get_metadata(key)
+if original_metadata:
+if changes.default_settings:
+original_metadata.default_settings = changes.default_settings
+
+metadata_store.update_metadata(key, original_metadata)
+else:
+metadata_store.add_metadata(
+key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
+)
+except Exception as e:
+raise HTTPException(
+status_code=500,
+detail=f"An error occurred while updating the model metadata: {e}",
+)
+
+result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
+
+return result


+@model_manager_router.get(
+"/tags",
+operation_id="list_tags",
+)
+async def list_tags() -> Set[str]:
+"""Get a unique set of all the model tags."""
+record_store = ApiDependencies.invoker.services.model_manager.store
+result: Set[str] = record_store.list_tags()
+return result


class FoundModel(BaseModel):
@@ -234,6 +366,19 @@ async def scan_for_models(
return scan_results


+@model_manager_router.get(
+"/tags/search",
+operation_id="search_by_metadata_tags",
+)
+async def search_by_metadata_tags(
+tags: Set[str] = Query(default=None, description="Tags to search for"),
+) -> ModelsList:
+"""Get a list of models."""
+record_store = ApiDependencies.invoker.services.model_manager.store
+results = record_store.search_by_metadata_tag(tags)
+return ModelsList(models=results)


@model_manager_router.patch(
"/i/{key}",
operation_id="update_model_record",
@@ -250,13 +395,15 @@ async def scan_for_models(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
-changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
+info: Annotated[
+AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
+],
) -> AnyModelConfig:
-"""Update a model's config."""
+"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
try:
-model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
+model_response: AnyModelConfig = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -268,14 +415,14 @@ async def update_model_record(

@model_manager_router.delete(
"/i/{key}",
-operation_id="delete_model",
+operation_id="del_model_record",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
)
-async def delete_model(
+async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""
@@ -296,39 +443,42 @@ async def delete_model(
raise HTTPException(status_code=404, detail=str(e))


-# @model_manager_router.post(
-# "/i/",
-# operation_id="add_model_record",
-# responses={
-# 201: {
-# "description": "The model added successfully",
-# "content": {"application/json": {"example": example_model_config}},
-# },
-# 409: {"description": "There is already a model corresponding to this path or repo_id"},
-# 415: {"description": "Unrecognized file/folder format"},
-# },
-# status_code=201,
-# )
-# async def add_model_record(
-# config: Annotated[
-# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
-# ],
-# ) -> AnyModelConfig:
-# """Add a model using the configuration information appropriate for its type."""
-# logger = ApiDependencies.invoker.services.logger
-# record_store = ApiDependencies.invoker.services.model_manager.store
-# try:
-# record_store.add_model(config)
-# except DuplicateModelException as e:
-# logger.error(str(e))
-# raise HTTPException(status_code=409, detail=str(e))
-# except InvalidModelException as e:
-# logger.error(str(e))
-# raise HTTPException(status_code=415)
-# # now fetch it out
-# result: AnyModelConfig = record_store.get_model(config.key)
-# return result
+@model_manager_router.post(
+"/i/",
+operation_id="add_model_record",
+responses={
+201: {
+"description": "The model added successfully",
+"content": {"application/json": {"example": example_model_config}},
+},
+409: {"description": "There is already a model corresponding to this path or repo_id"},
+415: {"description": "Unrecognized file/folder format"},
+},
+status_code=201,
+)
+async def add_model_record(
+config: Annotated[
+AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
+],
+) -> AnyModelConfig:
+"""Add a model using the configuration information appropriate for its type."""
+logger = ApiDependencies.invoker.services.logger
+record_store = ApiDependencies.invoker.services.model_manager.store
+if config.key == "<NOKEY>":
+config.key = sha1(randbytes(100)).hexdigest()
+logger.info(f"Created model {config.key} for {config.name}")
+try:
+record_store.add_model(config.key, config)
+except DuplicateModelException as e:
+logger.error(str(e))
+raise HTTPException(status_code=409, detail=str(e))
+except InvalidModelException as e:
+logger.error(str(e))
+raise HTTPException(status_code=415)
+
+# now fetch it out
+result: AnyModelConfig = record_store.get_model(config.key)
+return result


@model_manager_router.post(
@@ -403,10 +553,10 @@ async def install_model(


@model_manager_router.get(
-"/install",
-operation_id="list_model_installs",
+"/import",
+operation_id="list_model_install_jobs",
)
-async def list_model_installs() -> List[ModelInstallJob]:
+async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return the list of model install jobs.

Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@@ -420,8 +570,9 @@ async def list_model_installs() -> List[ModelInstallJob]:
* "cancelled" -- Job was cancelled before completion.

Once completed, information about the model such as its size, base
-model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
-information on individual files can be retrieved from `download_parts`.
+model, type, and metadata can be retrieved from the `config_out`
+field. For multi-file models such as diffusers, information on individual files
+can be retrieved from `download_parts`.

See the example and schema below for more information.
"""
@@ -430,7 +581,7 @@ async def list_model_installs() -> List[ModelInstallJob]:


@model_manager_router.get(
-"/install/{id}",
+"/import/{id}",
operation_id="get_model_install_job",
responses={
200: {"description": "Success"},
@@ -450,7 +601,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))


@model_manager_router.delete(
-"/install/{id}",
+"/import/{id}",
operation_id="cancel_model_install_job",
responses={
201: {"description": "The job was cancelled successfully"},
@@ -468,8 +619,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job)


-@model_manager_router.delete(
-"/install",
+@model_manager_router.patch(
+"/import",
operation_id="prune_model_install_jobs",
responses={
204: {"description": "All completed and errored jobs have been pruned"},
@@ -548,8 +699,7 @@ async def convert_model(
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
-changes = ModelRecordChanges(name=model_config.name)
-store.update_model(key, changes=changes)
+store.update_model(key, config=model_config)

# install the diffusers
try:
@@ -558,7 +708,7 @@ async def convert_model(
config={
"name": original_name,
"description": model_config.description,
-"hash": model_config.hash,
+"original_hash": model_config.original_hash,
"source": model_config.source,
},
)
@@ -566,6 +716,10 @@ async def convert_model(
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))

+# get the original metadata
+if orig_metadata := store.get_metadata(key):
+store.metadata_store.add_metadata(new_key, orig_metadata)
+
# delete the original safetensors file
installer.delete(key)

@@ -577,66 +731,66 @@ async def convert_model(
return new_config


-# @model_manager_router.put(
-# "/merge",
-# operation_id="merge",
-# responses={
-# 200: {
-# "description": "Model converted successfully",
-# "content": {"application/json": {"example": example_model_config}},
-# },
-# 400: {"description": "Bad request"},
-# 404: {"description": "Model not found"},
-# 409: {"description": "There is already a model registered at this location"},
-# },
-# )
-# async def merge(
-# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
-# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
-# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
-# force: bool = Body(
-# description="Force merging of models created with different versions of diffusers",
-# default=False,
-# ),
-# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
-# merge_dest_directory: Optional[str] = Body(
-# description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
-# default=None,
-# ),
-# ) -> AnyModelConfig:
-# """
-# Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
-# ```
-# Argument Description [default]
-# -------- ----------------------
-# keys List of 2-3 model keys to merge together. All models must use the same base type.
-# merged_model_name Name for the merged model [Concat model names]
-# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
-# force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
-# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
-# merge_dest_directory Specify a directory to store the merged model in [models directory]
-# ```
-# """
-# logger = ApiDependencies.invoker.services.logger
-# try:
-# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
-# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
-# installer = ApiDependencies.invoker.services.model_manager.install
-# merger = ModelMerger(installer)
-# model_names = [installer.record_store.get_model(x).name for x in keys]
-# response = merger.merge_diffusion_models_and_save(
-# model_keys=keys,
-# merged_model_name=merged_model_name or "+".join(model_names),
-# alpha=alpha,
-# interp=interp,
-# force=force,
-# merge_dest_directory=dest,
-# )
-# except UnknownModelException:
-# raise HTTPException(
-# status_code=404,
-# detail=f"One or more of the models '{keys}' not found",
-# )
-# except ValueError as e:
-# raise HTTPException(status_code=400, detail=str(e))
-# return response
+@model_manager_router.put(
+"/merge",
+operation_id="merge",
+responses={
+200: {
+"description": "Model converted successfully",
+"content": {"application/json": {"example": example_model_config}},
+},
+400: {"description": "Bad request"},
+404: {"description": "Model not found"},
+409: {"description": "There is already a model registered at this location"},
+},
+)
+async def merge(
+keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
+merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
+alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
+force: bool = Body(
+description="Force merging of models created with different versions of diffusers",
+default=False,
+),
+interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
+merge_dest_directory: Optional[str] = Body(
+description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
+default=None,
+),
+) -> AnyModelConfig:
+"""
+Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
+```
+Argument Description [default]
+-------- ----------------------
+keys List of 2-3 model keys to merge together. All models must use the same base type.
+merged_model_name Name for the merged model [Concat model names]
+alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
+force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
+interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
+merge_dest_directory Specify a directory to store the merged model in [models directory]
+```
+"""
+logger = ApiDependencies.invoker.services.logger
+try:
+logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
+dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
+installer = ApiDependencies.invoker.services.model_manager.install
+merger = ModelMerger(installer)
+model_names = [installer.record_store.get_model(x).name for x in keys]
+response = merger.merge_diffusion_models_and_save(
+model_keys=keys,
+merged_model_name=merged_model_name or "+".join(model_names),
+alpha=alpha,
+interp=interp,
+force=force,
+merge_dest_directory=dest,
+)
+except UnknownModelException:
+raise HTTPException(
+status_code=404,
+detail=f"One or more of the models '{keys}' not found",
+)
+except ValueError as e:
+raise HTTPException(status_code=400, detail=str(e))
+return response
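For orientation, here is a minimal client-side sketch of calling the re-enabled `/merge` route above. It is illustrative only: the host, port, and `/api/v2/models` router prefix are assumptions that do not appear in this diff, the model keys are placeholders, and only the body fields documented in the route's docstring are used.

# Hypothetical usage sketch -- base URL and router prefix are assumptions, not taken from this diff.
import requests

payload = {
    "keys": ["<key-of-model-a>", "<key-of-model-b>"],  # placeholder model keys
    "merged_model_name": "my-merged-model",
    "alpha": 0.5,                  # weight given to the second model
    "force": False,                # allow merging across diffusers versions
    "interp": "weighted_sum",      # or "sigmoid", "inv_sigmoid", "add_difference"
    "merge_dest_directory": None,  # default: the models directory
}

resp = requests.put("http://127.0.0.1:9090/api/v2/models/merge", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json())  # the merged model's config record is returned on success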
@@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
+from invokeai.app.invocations.fields import (
+ConditioningField,
+FieldDescriptions,
+Input,
+InputField,
+MaskField,
+OutputField,
+UIComponent,
+)
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
@@ -36,7 +44,7 @@ from .model import ClipField
title="Prompt",
tags=["prompt", "compel"],
category="conditioning",
-version="1.0.1",
+version="1.2.0",
)
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
@@ -51,6 +59,10 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
+mask: Optional[MaskField] = InputField(
+default=None, description="A mask defining the region that this conditioning prompt applies to."
+)
+mask_weight: float = InputField(default=1.0, description="")

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@@ -118,7 +130,13 @@ class CompelInvocation(BaseInvocation):

conditioning_name = context.conditioning.save(conditioning_data)

-return ConditioningOutput.build(conditioning_name)
+return ConditioningOutput(
+conditioning=ConditioningField(
+conditioning_name=conditioning_name,
+mask=self.mask,
+mask_weight=self.mask_weight,
+)
+)


class SDXLPromptInvocationBase:
@@ -232,7 +250,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
-version="1.0.1",
+version="1.2.0",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@@ -256,6 +274,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")

+mask: Optional[MaskField] = InputField(
+default=None, description="A mask defining the region that this conditioning prompt applies to."
+)
+mask_weight: float = InputField(default=1.0, description="")

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(
@@ -317,7 +340,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):

conditioning_name = context.conditioning.save(conditioning_data)

-return ConditioningOutput.build(conditioning_name)
+return ConditioningOutput(
+conditioning=ConditioningField(
+conditioning_name=conditioning_name,
+mask=self.mask,
+mask_weight=self.mask_weight,
+)
+)


@invocation(
@@ -366,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase

conditioning_name = context.conditioning.save(conditioning_data)

-return ConditioningOutput.build(conditioning_name)
+return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name, mask_weight=1.0))


@invocation_output("clip_skip_output")
invokeai/app/invocations/conditioning.py (Normal file, 40 lines added)
@@ -0,0 +1,40 @@
+import torch
+
+from invokeai.app.invocations.baseinvocation import (
+BaseInvocation,
+InvocationContext,
+invocation,
+)
+from invokeai.app.invocations.fields import InputField, WithMetadata
+from invokeai.app.invocations.primitives import MaskField, MaskOutput
+
+
+@invocation(
+"rectangle_mask",
+title="Create Rectangle Mask",
+tags=["conditioning"],
+category="conditioning",
+version="1.0.0",
+)
+class RectangleMaskInvocation(BaseInvocation, WithMetadata):
+"""Create a rectangular mask."""
+
+height: int = InputField(description="The height of the entire mask.")
+width: int = InputField(description="The width of the entire mask.")
+y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
+x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
+rectangle_height: int = InputField(description="The height of the rectangular masked region.")
+rectangle_width: int = InputField(description="The width of the rectangular masked region.")
+
+def invoke(self, context: InvocationContext) -> MaskOutput:
+mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
+mask[
+:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
+] = True
+
+mask_name = context.tensors.save(mask)
+return MaskOutput(
+mask=MaskField(mask_name=mask_name),
+width=self.width,
+height=self.height,
+)
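As a quick, self-contained illustration of the tensor contract established by the new RectangleMaskInvocation above: the node produces a boolean tensor of shape (1, height, width) with the rectangle set to True. The sketch below is plain torch with no InvokeAI imports; the dimensions are arbitrary examples.

# Minimal sketch of the mask layout built in invoke() above (illustrative sizes only).
import torch

height, width = 64, 96   # size of the full mask
y_top, x_left = 8, 16    # top-left corner of the masked rectangle (inclusive)
rect_h, rect_w = 32, 48  # rectangle size

mask = torch.zeros((1, height, width), dtype=torch.bool)
mask[:, y_top : y_top + rect_h, x_left : x_left + rect_w] = True

assert mask.shape == (1, height, width)
assert mask.sum().item() == rect_h * rect_w  # only the rectangle is True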
@@ -194,6 +194,12 @@ class BoardField(BaseModel):
board_id: str = Field(description="The id of the board")


+class MaskField(BaseModel):
+"""A mask primitive field."""
+
+mask_name: str = Field(description="The name of the mask.")
+
+
class DenoiseMaskField(BaseModel):
"""An inpaint mask field"""

@@ -225,7 +231,12 @@ class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

conditioning_name: str = Field(description="The name of conditioning tensor")
-# endregion
+mask: Optional[MaskField] = Field(
+default=None,
+description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
+"included regions should be set to True.",
+)
+mask_weight: float = Field(description="")


class MetadataField(RootModel):
@@ -1,5 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+import inspect
import math
from contextlib import ExitStack
from functools import singledispatchmethod
@@ -9,6 +9,7 @@ import einops
import numpy as np
import numpy.typing as npt
import torch
+import torchvision
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.configuration_utils import ConfigMixin
@@ -55,7 +56,14 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+BasicConditioningInfo,
+IPAdapterConditioningInfo,
+Range,
+SDXLConditioningInfo,
+TextConditioningData,
+TextConditioningRegions,
+)
from invokeai.backend.util.silence_warnings import SilenceWarnings

from ...backend.stable_diffusion.diffusers_pipeline import (
@@ -65,7 +73,6 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
T2IAdapterData,
image_resized_to_grid_as_tensor,
)
-from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import (
@@ -284,11 +291,11 @@ def get_scheduler(
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""

-positive_conditioning: ConditioningField = InputField(
+positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
)
-negative_conditioning: ConditioningField = InputField(
-description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
+negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
+description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=0
)
noise: Optional[LatentsField] = InputField(
default=None,
@@ -365,39 +372,190 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v

+def _get_text_embeddings_and_masks(
+self,
+cond_list: list[ConditioningField],
+context: InvocationContext,
+device: torch.device,
+dtype: torch.dtype,
+) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
+"""Get the text embeddings and masks from the input conditioning fields."""
+text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
+text_embeddings_masks: list[Optional[torch.Tensor]] = []
+for cond in cond_list:
+cond_data = context.conditioning.load(cond.conditioning_name)
+text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
+
+mask = cond.mask
+if mask is not None:
+mask = context.tensors.load(mask.mask_name)
+text_embeddings_masks.append(mask)
+
+return text_embeddings, text_embeddings_masks
+
+def _preprocess_regional_prompt_mask(
+self, mask: Optional[torch.Tensor], target_height: int, target_width: int
+) -> torch.Tensor:
+"""Preprocess a regional prompt mask to match the target height and width.
+
+If mask is None, returns a mask of all ones with the target height and width.
+If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
+
+Returns:
+torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
+"""
+if mask is None:
+return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
+
+tf = torchvision.transforms.Resize(
+(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+)
+mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
+mask = tf(mask)
+
+return mask
+
+def concat_regional_text_embeddings(
+self,
+text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
+masks: Optional[list[Optional[torch.Tensor]]],
+conditioning_fields: list[ConditioningField],
+latent_height: int,
+latent_width: int,
+) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
+"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
+if masks is None:
+masks = [None] * len(text_conditionings)
+assert len(text_conditionings) == len(masks)
+
+is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
+
+all_masks_are_none = all(mask is None for mask in masks)
+
+text_embedding = []
+pooled_embedding = None
+add_time_ids = None
+cur_text_embedding_len = 0
+processed_masks = []
+embedding_ranges = []
+extra_conditioning = None
+
+for prompt_idx, text_embedding_info in enumerate(text_conditionings):
+mask = masks[prompt_idx]
+if (
+text_embedding_info.extra_conditioning is not None
+and text_embedding_info.extra_conditioning.wants_cross_attention_control
+):
+extra_conditioning = text_embedding_info.extra_conditioning
+
+if is_sdxl:
+# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
+# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
+# global prompt information. In an ideal case, there should be exactly one global prompt without a
+# mask, but we don't enforce this.
+
+# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
+# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
+# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
+# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
+# pretty major breaking change to a popular node, so for now we use this hack.
+if pooled_embedding is None or mask is None:
+pooled_embedding = text_embedding_info.pooled_embeds
+if add_time_ids is None or mask is None:
+add_time_ids = text_embedding_info.add_time_ids
+
+text_embedding.append(text_embedding_info.embeds)
+if not all_masks_are_none:
+# embedding_ranges.append(
+# Range(
+# start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
+# )
+# )
+# HACK(ryand): Contrary to its name, tokens_count_including_eos_bos does not seem to include eos and bos
+# in the count.
+embedding_ranges.append(
+Range(
+start=cur_text_embedding_len + 1,
+end=cur_text_embedding_len
++ text_embedding_info.extra_conditioning.tokens_count_including_eos_bos,
+)
+)
+processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
+
+cur_text_embedding_len += text_embedding_info.embeds.shape[1]
+
+text_embedding = torch.cat(text_embedding, dim=1)
+assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
+
+regions = None
+if not all_masks_are_none:
+regions = TextConditioningRegions(
+masks=torch.cat(processed_masks, dim=1),
+ranges=embedding_ranges,
+mask_weights=[x.mask_weight for x in conditioning_fields],
+)
+
+if extra_conditioning is not None and len(text_conditionings) > 1:
+raise ValueError(
+"Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
+"prompts."
+)
+
+if is_sdxl:
+return SDXLConditioningInfo(
+embeds=text_embedding,
+extra_conditioning=extra_conditioning,
+pooled_embeds=pooled_embedding,
+add_time_ids=add_time_ids,
+), regions
+return BasicConditioningInfo(
+embeds=text_embedding,
+extra_conditioning=extra_conditioning,
+), regions

def get_conditioning_data(
self,
context: InvocationContext,
-scheduler: Scheduler,
unet: UNet2DConditionModel,
-seed: int,
-) -> ConditioningData:
-positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
-c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
+latent_height: int,
+latent_width: int,
+) -> TextConditioningData:
+# Normalize self.positive_conditioning and self.negative_conditioning to lists.
+cond_list = self.positive_conditioning
+if not isinstance(cond_list, list):
+cond_list = [cond_list]
+uncond_list = self.negative_conditioning
+if not isinstance(uncond_list, list):
+uncond_list = [uncond_list]
+
-negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
-uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
+cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
+cond_list, context, unet.device, unet.dtype

-conditioning_data = ConditioningData(
-unconditioned_embeddings=uc,
-text_embeddings=c,
-guidance_scale=self.cfg_scale,
-guidance_rescale_multiplier=self.cfg_rescale_multiplier,
-postprocessing_settings=PostprocessingSettings(
-threshold=0.0, # threshold,
-warmup=0.2, # warmup,
-h_symmetry_time_pct=None, # h_symmetry_time_pct,
-v_symmetry_time_pct=None, # v_symmetry_time_pct,
-),
)

-conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
-scheduler,
-# for ddim scheduler
-eta=0.0, # ddim_eta
-# for ancestral and sde schedulers
-# flip all bits to have noise different from initial
-generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
+uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
+uncond_list, context, unet.device, unet.dtype
+)
+cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
+text_conditionings=cond_text_embeddings,
+masks=cond_text_embedding_masks,
+conditioning_fields=cond_list,
+latent_height=latent_height,
+latent_width=latent_width,
+)
+uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
+text_conditionings=uncond_text_embeddings,
+masks=uncond_text_embedding_masks,
+conditioning_fields=uncond_list,
+latent_height=latent_height,
+latent_width=latent_width,
+)
+conditioning_data = TextConditioningData(
+uncond_text=uncond_text_embedding,
+cond_text=cond_text_embedding,
+uncond_regions=uncond_regions,
+cond_regions=cond_regions,
+guidance_scale=self.cfg_scale,
+guidance_rescale_multiplier=self.cfg_rescale_multiplier,
)
return conditioning_data

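As a rough mental model of the bookkeeping that `concat_regional_text_embeddings` performs for `get_conditioning_data` above (ignoring the SDXL pooled-embedding handling and the token-range off-by-one hack noted in its comments), here is a toy, self-contained sketch; all shapes and sizes are illustrative assumptions, not values from the code.

# Toy sketch: per-prompt embeddings are concatenated along the token axis, and each
# prompt's (start, end) span in the joint sequence is recorded so that its region mask
# can be applied to the matching tokens later.
import torch

prompt_embeds = [torch.randn(1, 77, 768), torch.randn(1, 77, 768)]  # two prompts' text embeddings
prompt_masks = [
    torch.ones((1, 1, 64, 64), dtype=torch.bool),   # global prompt: covers the whole latent grid
    torch.zeros((1, 1, 64, 64), dtype=torch.bool),  # regional prompt: top half only
]
prompt_masks[1][:, :, :32, :] = True

spans = []   # (start, end) token range of each prompt in the concatenated sequence
cursor = 0
for embeds in prompt_embeds:
    spans.append((cursor, cursor + embeds.shape[1]))
    cursor += embeds.shape[1]

joint_embeds = torch.cat(prompt_embeds, dim=1)  # (1, 154, 768): one long token sequence
joint_masks = torch.cat(prompt_masks, dim=1)    # (1, 2, 64, 64): one mask channel per prompt

assert spans == [(0, 77), (77, 154)]
assert joint_embeds.shape == (1, 154, 768)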
@@ -503,7 +661,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
self,
context: InvocationContext,
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
-conditioning_data: ConditioningData,
exit_stack: ExitStack,
) -> Optional[list[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
@@ -520,7 +677,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
return None

ip_adapter_data_list = []
-conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
@@ -543,16 +699,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
single_ipa_images, image_encoder_model
)

-conditioning_data.ip_adapter_conditioning.append(
-IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
-)

ip_adapter_data_list.append(
IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight,
begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent,
+ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
)
)

@@ -642,6 +795,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
steps: int,
denoising_start: float,
denoising_end: float,
+seed: int,
) -> Tuple[int, List[int], int]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
@@ -670,7 +824,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order

-return num_inference_steps, timesteps, init_timestep
+scheduler_step_kwargs = {}
+scheduler_step_signature = inspect.signature(scheduler.step)
+if "generator" in scheduler_step_signature.parameters:
+# At some point, someone decided that schedulers that accept a generator should use the original seed with
+# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
+# reproducibility.
+scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
+
+return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs

def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
@@ -763,7 +925,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
)

pipeline = self.create_pipeline(unet, scheduler)
-conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
+_, _, latent_height, latent_width = latents.shape
+conditioning_data = self.get_conditioning_data(
+context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+)

controlnet_data = self.prep_control_data(
context=context,
@@ -777,16 +942,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapter=self.ip_adapter,
-conditioning_data=conditioning_data,
exit_stack=exit_stack,
)

-num_inference_steps, timesteps, init_timestep = self.init_scheduler(
+num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
+seed=seed,
)

result_latents = pipeline.latents_from_embeddings(
@@ -799,6 +964,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
masked_latents=masked_latents,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps,
+scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,
ip_adapter_data=ip_adapter_data,
@@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=key,
-submodel_type=SubModelType.VAE,
+submodel_type=SubModelType.Vae,
),
),
)
@@ -14,6 +14,7 @@ from invokeai.app.invocations.fields import (
Input,
InputField,
LatentsField,
+MaskField,
OutputField,
UIComponent,
)
@@ -229,6 +230,18 @@ class StringCollectionInvocation(BaseInvocation):
# region Image


+@invocation_output("mask_output")
+class MaskOutput(BaseInvocationOutput):
+"""A torch mask tensor.
+dtype: torch.bool
+shape: (1, height, width).
+"""
+
+mask: MaskField = OutputField(description="The mask.")
+width: int = OutputField(description="The width of the mask in pixels.")
+height: int = OutputField(description="The height of the mask in pixels.")
+
+
@invocation_output("image_output")
class ImageOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""
@@ -414,10 +427,6 @@ class ConditioningOutput(BaseInvocationOutput):

conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)

-@classmethod
-def build(cls, conditioning_name: str) -> "ConditioningOutput":
-return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))


@invocation_output("conditioning_collection_output")
class ConditioningCollectionOutput(BaseInvocationOutput):
@@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=model_key,
-submodel_type=SubModelType.VAE,
+submodel_type=SubModelType.Vae,
),
),
)
@@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=model_key,
-submodel_type=SubModelType.VAE,
+submodel_type=SubModelType.Vae,
),
),
)
@@ -256,7 +256,6 @@ class InvokeAIAppConfig(InvokeAISettings):
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
-skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)

version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)

@@ -18,9 +18,10 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
-from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata

+from ..model_metadata import ModelMetadataStoreBase
+

class InstallStatus(str, Enum):
    """State of an install job running in the background."""

@@ -150,13 +151,6 @@ ModelSource = Annotated[
    Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
]

-MODEL_SOURCE_TO_TYPE_MAP = {
-    URLModelSource: ModelSourceType.Url,
-    HFModelSource: ModelSourceType.HFRepoID,
-    CivitaiModelSource: ModelSourceType.CivitAI,
-    LocalModelSource: ModelSourceType.Path,
-}
-

class ModelInstallJob(BaseModel):
    """Object that tracks the current status of an install request."""

@@ -266,6 +260,7 @@ class ModelInstallServiceBase(ABC):
        app_config: InvokeAIAppConfig,
        record_store: ModelRecordServiceBase,
        download_queue: DownloadQueueServiceBase,
+       metadata_store: ModelMetadataStoreBase,
        event_bus: Optional["EventServiceBase"] = None,
    ):
        """

@@ -352,7 +347,6 @@ class ModelInstallServiceBase(ABC):
        source: str,
        config: Optional[Dict[str, Any]] = None,
        access_token: Optional[str] = None,
-       inplace: Optional[bool] = False,
    ) -> ModelInstallJob:
        r"""Install the indicated model using heuristics to interpret user intentions.

@@ -398,7 +392,7 @@ class ModelInstallServiceBase(ABC):
        will override corresponding autoassigned probe fields in the
        model's config record. Use it to override
        `name`, `description`, `base_type`, `model_type`, `format`,
-       `prediction_type`, and/or `image_size`.
+       `prediction_type`, `image_size`, and/or `ztsnr_training`.

        This will download the model located at `source`,
        probe it, and install it into the models directory.
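The docstring above lists the probe fields that a caller-supplied `config` dict may override during a heuristic install. A hedged sketch of what such a call might look like, assuming the method is the installer's `heuristic_import` and with a placeholder URL and override values:

    # Hypothetical call; the URL and the override values are illustrative only.
    job = installer.heuristic_import(
        source="https://example.com/path/to/model.safetensors",
        config={
            "name": "my-model",
            "description": "Installed with overridden probe fields",
            "prediction_type": "v_prediction",   # one of the overridable fields listed above
        },
    )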
@@ -20,15 +20,12 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
-from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
-   CheckpointConfigBase,
    InvalidModelConfigException,
    ModelRepoVariant,
-   ModelSourceType,
    ModelType,
)
from invokeai.backend.model_manager.metadata import (

@@ -38,14 +35,12 @@ from invokeai.backend.model_manager.metadata import (
    ModelMetadataWithFiles,
    RemoteModelFile,
)
-from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device

from .model_install_base import (
-   MODEL_SOURCE_TO_TYPE_MAP,
    CivitaiModelSource,
    HFModelSource,
    InstallStatus,

@@ -95,6 +90,7 @@ class ModelInstallService(ModelInstallServiceBase):
        self._running = False
        self._session = session
        self._next_job_id = 0
+       self._metadata_store = record_store.metadata_store  # for convenience

    @property
    def app_config(self) -> InvokeAIAppConfig:  # noqa D102

@@ -143,7 +139,6 @@ class ModelInstallService(ModelInstallServiceBase):
        config = config or {}
        if not config.get("source"):
            config["source"] = model_path.resolve().as_posix()
-       config["source_type"] = ModelSourceType.Path
        return self._register(model_path, config)

    def install_path(

@@ -153,11 +148,11 @@ class ModelInstallService(ModelInstallServiceBase):
    ) -> str:  # noqa D102
        model_path = Path(model_path)
        config = config or {}
-
-       if self._app_config.skip_model_hash:
-           config["hash"] = uuid_string()
-
-       info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
+       if not config.get("source"):
+           config["source"] = model_path.resolve().as_posix()
+       config["key"] = config.get("key", uuid_string())
+
+       info: AnyModelConfig = self._probe_model(Path(model_path), config)

        if preferred_name := config.get("name"):
            preferred_name = Path(preferred_name).with_suffix(model_path.suffix)

@@ -183,7 +178,7 @@ class ModelInstallService(ModelInstallServiceBase):
        source: str,
        config: Optional[Dict[str, Any]] = None,
        access_token: Optional[str] = None,
-       inplace: Optional[bool] = False,
+       inplace: bool = False,
    ) -> ModelInstallJob:
        variants = "|".join(ModelRepoVariant.__members__.values())
        hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"

@@ -379,18 +374,15 @@ class ModelInstallService(ModelInstallServiceBase):
            job.bytes = job.total_bytes
            self._signal_job_running(job)
            job.config_in["source"] = str(job.source)
-           job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
-           # enter the metadata, if there is any
-           if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
-               job.config_in["source_api_response"] = job.source_metadata.api_response
-           if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
-               job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases

            if job.inplace:
                key = self.register_path(job.local_path, job.config_in)
            else:
                key = self.install_path(job.local_path, job.config_in)
            job.config_out = self.record_store.get_model(key)

+           # enter the metadata, if there is any
+           if job.source_metadata:
+               self._metadata_store.add_metadata(key, job.source_metadata)
            self._signal_job_completed(job)

        except InvalidModelConfigException as excp:

@@ -476,7 +468,7 @@ class ModelInstallService(ModelInstallServiceBase):
        self._logger.info(f"Moving {model.name} to {new_path}.")
        new_path = self._move_model(old_path, new_path)
        model.path = new_path.relative_to(models_dir).as_posix()
-       self.record_store.update_model(key, ModelRecordChanges(path=model.path))
+       self.record_store.update_model(key, model)
        return model

    def _scan_register(self, model: Path) -> bool:

@@ -528,15 +520,24 @@ class ModelInstallService(ModelInstallServiceBase):
        move(old_path, new_path)
        return new_path

+   def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
+       info: AnyModelConfig = ModelProbe.probe(Path(model_path))
+       if config:  # used to override probe fields
+           for key, value in config.items():
+               setattr(info, key, value)
+       return info
+
    def _register(
        self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
    ) -> str:
-       config = config or {}
-
-       if self._app_config.skip_model_hash:
-           config["hash"] = uuid_string()
-
+       # Note that we may be passed a pre-populated AnyModelConfig object,
+       # in which case the key field should have been populated by the caller (e.g. in `install_path`).
+       config["key"] = config.get("key", uuid_string())
+
        info = info or ModelProbe.probe(model_path, config)
+       override_key: Optional[str] = config.get("key") if config else None
+
+       assert info.original_hash  # always assigned by probe()
+       info.key = override_key or info.original_hash

        model_path = model_path.absolute()
        if model_path.is_relative_to(self.app_config.models_path):

@@ -545,11 +546,11 @@ class ModelInstallService(ModelInstallServiceBase):
        info.path = model_path.as_posix()

        # add 'main' specific fields
-       if isinstance(info, CheckpointConfigBase):
+       if hasattr(info, "config"):
            # make config relative to our root
-           legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
-           info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
-       self.record_store.add_model(info)
+           legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
+           info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
+       self.record_store.add_model(info.key, info)
        return info.key

    def _next_id(self) -> int:

@@ -570,15 +571,13 @@ class ModelInstallService(ModelInstallServiceBase):
            source=source,
            config_in=config or {},
            local_path=Path(source.path),
-           inplace=source.inplace or False,
+           inplace=source.inplace,
        )

    def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
        if not source.access_token:
            self._logger.info("No Civitai access token provided; some models may not be downloadable.")
-       metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
-           str(source.version_id)
-       )
+       metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
        assert isinstance(metadata, ModelMetadataWithFiles)
        remote_files = metadata.download_urls(session=self._session)
        return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)

@@ -606,17 +605,15 @@ class ModelInstallService(ModelInstallServiceBase):

    def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
        # URLs from Civitai or HuggingFace will be handled specially
+       url_patterns = {
+           r"^https?://civitai.com/": CivitaiMetadataFetch,
+           r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
+       }
        metadata = None
-       fetcher = None
-       try:
-           fetcher = self.get_fetcher_from_url(str(source.url))
-       except ValueError:
-           pass
-       kwargs: dict[str, Any] = {"session": self._session}
-       if fetcher is CivitaiMetadataFetch:
-           kwargs["api_key"] = self._app_config.get_config().civitai_api_key
-       if fetcher is not None:
-           metadata = fetcher(**kwargs).from_url(source.url)
+       for pattern, fetcher in url_patterns.items():
+           if re.match(pattern, str(source.url), re.IGNORECASE):
+               metadata = fetcher(self._session).from_url(source.url)
+               break
        self._logger.debug(f"metadata={metadata}")
        if metadata and isinstance(metadata, ModelMetadataWithFiles):
            remote_files = metadata.download_urls(session=self._session)

@@ -631,7 +628,7 @@ class ModelInstallService(ModelInstallServiceBase):

    def _import_remote_model(
        self,
-       source: HFModelSource | CivitaiModelSource | URLModelSource,
+       source: ModelSource,
        remote_files: List[RemoteModelFile],
        metadata: Optional[AnyModelRepoMetadata],
        config: Optional[Dict[str, Any]],

@@ -659,7 +656,7 @@ class ModelInstallService(ModelInstallServiceBase):
        # In the event that there is a subfolder specified in the source,
        # we need to remove it from the destination path in order to avoid
        # creating unwanted subfolders
-       if isinstance(source, HFModelSource) and source.subfolder:
+       if hasattr(source, "subfolder") and source.subfolder:
            root = Path(remote_files[0].path.parts[0])
            subfolder = root / source.subfolder
        else:

@@ -846,11 +843,3 @@ class ModelInstallService(ModelInstallServiceBase):
        self._logger.info(f"{job.source}: model installation was cancelled")
        if self._event_bus:
            self._event_bus.emit_model_install_cancelled(str(job.source))
-
-   @staticmethod
-   def get_fetcher_from_url(url: str):
-       if re.match(r"^https?://civitai.com/", url.lower()):
-           return CivitaiMetadataFetch
-       elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
-           return HuggingFaceMetadataFetch
-       raise ValueError(f"Unsupported model source: '{url}'")
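The final hunk above removes the `get_fetcher_from_url` helper, whose job the inline `url_patterns` loop takes over. A short sketch of how the removed helper was driven, pieced together from the removed lines; the URL is a placeholder, and `session` and `api_key` stand in for the values the service reads from its own configuration:

    # Hedged sketch of the removed code path, not a verbatim excerpt.
    url = "https://civitai.com/models/000000"   # placeholder
    try:
        fetcher = ModelInstallService.get_fetcher_from_url(url)
    except ValueError:
        fetcher = None                           # unrecognized URLs fall back to a bare download

    if fetcher is not None:
        kwargs: dict = {"session": session}
        if fetcher is CivitaiMetadataFetch:
            kwargs["api_key"] = api_key          # only the Civitai fetcher takes an API key
        metadata = fetcher(**kwargs).from_url(url)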
invokeai/app/services/model_metadata/__init__.py (new file)
@@ -0,0 +1,9 @@
"""Init file for ModelMetadataStoreService module."""

from .metadata_store_base import ModelMetadataStoreBase
from .metadata_store_sql import ModelMetadataStoreSQL

__all__ = [
    "ModelMetadataStoreBase",
    "ModelMetadataStoreSQL",
]
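With the package `__init__` above in place, consumers can import the metadata store through the package rather than its concrete module. A trivial sketch, assuming `db` is the application's SqliteDatabase instance:

    from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL

    def open_store(db) -> ModelMetadataStoreBase:
        # Minimal factory; callers only depend on the abstract base.
        return ModelMetadataStoreSQL(db)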
invokeai/app/services/model_metadata/metadata_store_base.py (new file, 81 lines)
@@ -0,0 +1,81 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Storage for Model Metadata
"""

from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple

from pydantic import Field

from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings


class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
    """A set of changes to apply to model metadata.

    Only limited changes are valid:
      - `default_settings`: the user-configured default settings for this model
    """

    default_settings: Optional[ModelDefaultSettings] = Field(
        default=None, description="The user-configured default settings for this model"
    )
    """The user-configured default settings for this model"""


class ModelMetadataStoreBase(ABC):
    """Store, search and fetch model metadata retrieved from remote repositories."""

    @abstractmethod
    def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
        """
        Add a block of repo metadata to a model record.

        The model record config must already exist in the database with the
        same key. Otherwise a FOREIGN KEY constraint exception will be raised.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to store
        """

    @abstractmethod
    def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
        """Retrieve the ModelRepoMetadata corresponding to model key."""

    @abstractmethod
    def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:  # key, metadata
        """Dump out all the metadata."""

    @abstractmethod
    def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
        """
        Update metadata corresponding to the model with the indicated key.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to update
        """

    @abstractmethod
    def list_tags(self) -> Set[str]:
        """Return all tags in the tags table."""

    @abstractmethod
    def search_by_tag(self, tags: Set[str]) -> Set[str]:
        """Return the keys of models containing all of the listed tags."""

    @abstractmethod
    def search_by_author(self, author: str) -> Set[str]:
        """Return the keys of models authored by the indicated author."""

    @abstractmethod
    def search_by_name(self, name: str) -> Set[str]:
        """
        Return the keys of models with the indicated name.

        Note that this is the name of the model given to it by
        the remote source. The user may have changed the local
        name. The local name will be located in the model config
        record object.
        """
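A hedged sketch of how a caller is expected to use the abstract store above; `store` stands in for any concrete implementation (such as the SQL-backed one defined in the next file), and it is assumed that the metadata object exposes a `tags` set, as the SQL implementation relies on:

    from typing import Set

    def tags_for_model(store: ModelMetadataStoreBase, key: str) -> Set[str]:
        metadata = store.get_metadata(key)   # raises if no metadata is stored for this key
        return metadata.tags or set()

    def models_tagged(store: ModelMetadataStoreBase, wanted: Set[str]) -> Set[str]:
        # search_by_tag returns the keys of models whose metadata carries *all* of `wanted`
        return store.search_by_tag(wanted)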
invokeai/app/services/model_metadata/metadata_store_sql.py (new file, 223 lines)
@@ -0,0 +1,223 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""

import sqlite3
from typing import List, Optional, Set, Tuple

from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase

from .metadata_store_base import ModelMetadataStoreBase


class ModelMetadataStoreSQL(ModelMetadataStoreBase):
    """Store, search and fetch model metadata retrieved from remote repositories."""

    def __init__(self, db: SqliteDatabase):
        """
        Initialize a new object from preexisting sqlite3 connection and threading lock objects.

        :param conn: sqlite3 connection object
        :param lock: threading Lock object
        """
        super().__init__()
        self._db = db
        self._cursor = self._db.conn.cursor()

    def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
        """
        Add a block of repo metadata to a model record.

        The model record config must already exist in the database with the
        same key. Otherwise a FOREIGN KEY constraint exception will be raised.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to store
        """
        json_serialized = metadata.model_dump_json()
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    INSERT INTO model_metadata(
                        id,
                        metadata
                    )
                    VALUES (?,?);
                    """,
                    (
                        model_key,
                        json_serialized,
                    ),
                )
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except sqlite3.IntegrityError as excp:  # FOREIGN KEY error: the key was not in model_config table
                self._db.conn.rollback()
                raise UnknownMetadataException from excp
            except sqlite3.Error as excp:
                self._db.conn.rollback()
                raise excp

    def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
        """Retrieve the ModelRepoMetadata corresponding to model key."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT metadata FROM model_metadata
                WHERE id=?;
                """,
                (model_key,),
            )
            rows = self._cursor.fetchone()
            if not rows:
                raise UnknownMetadataException("model metadata not found")
            return ModelMetadataFetchBase.from_json(rows[0])

    def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:  # key, metadata
        """Dump out all the metadata."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT id,metadata FROM model_metadata;
                """,
                (),
            )
            rows = self._cursor.fetchall()
            return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]

    def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
        """
        Update metadata corresponding to the model with the indicated key.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to update
        """
        json_serialized = metadata.model_dump_json()  # turn it into a json string.
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    UPDATE model_metadata
                    SET
                        metadata=?
                    WHERE id=?;
                    """,
                    (json_serialized, model_key),
                )
                if self._cursor.rowcount == 0:
                    raise UnknownMetadataException("model metadata not found")
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except sqlite3.Error as e:
                self._db.conn.rollback()
                raise e

        return self.get_metadata(model_key)

    def list_tags(self) -> Set[str]:
        """Return all tags in the tags table."""
        self._cursor.execute(
            """--sql
            select tag_text from tags;
            """
        )
        return {x[0] for x in self._cursor.fetchall()}

    def search_by_tag(self, tags: Set[str]) -> Set[str]:
        """Return the keys of models containing all of the listed tags."""
        with self._db.lock:
            try:
                matches: Optional[Set[str]] = None
                for tag in tags:
                    self._cursor.execute(
                        """--sql
                        SELECT a.model_id FROM model_tags AS a,
                                               tags AS b
                        WHERE a.tag_id=b.tag_id
                          AND b.tag_text=?;
                        """,
                        (tag,),
                    )
                    model_keys = {x[0] for x in self._cursor.fetchall()}
                    if matches is None:
                        matches = model_keys
                    matches = matches.intersection(model_keys)
            except sqlite3.Error as e:
                raise e
        return matches if matches else set()

    def search_by_author(self, author: str) -> Set[str]:
        """Return the keys of models authored by the indicated author."""
        self._cursor.execute(
            """--sql
            SELECT id FROM model_metadata
            WHERE author=?;
            """,
            (author,),
        )
        return {x[0] for x in self._cursor.fetchall()}

    def search_by_name(self, name: str) -> Set[str]:
        """
        Return the keys of models with the indicated name.

        Note that this is the name of the model given to it by
        the remote source. The user may have changed the local
        name. The local name will be located in the model config
        record object.
        """
        self._cursor.execute(
            """--sql
            SELECT id FROM model_metadata
            WHERE name=?;
            """,
            (name,),
        )
        return {x[0] for x in self._cursor.fetchall()}

    def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
        """Update tags for the model referenced by model_key."""
        if tags:
            # remove previous tags from this model
            self._cursor.execute(
                """--sql
                DELETE FROM model_tags
                WHERE model_id=?;
                """,
                (model_key,),
            )

            for tag in tags:
                self._cursor.execute(
                    """--sql
                    INSERT OR IGNORE INTO tags (
                        tag_text
                    )
                    VALUES (?);
                    """,
                    (tag,),
                )
                self._cursor.execute(
                    """--sql
                    SELECT tag_id
                    FROM tags
                    WHERE tag_text = ?
                    LIMIT 1;
                    """,
                    (tag,),
                )
                tag_id = self._cursor.fetchone()[0]
                self._cursor.execute(
                    """--sql
                    INSERT OR IGNORE INTO model_tags (
                        model_id,
                        tag_id
                    )
                    VALUES (?,?);
                    """,
                    (model_key, tag_id),
                )
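A hedged usage sketch of the SQL store above. It assumes the `model_metadata`, `tags`, and `model_tags` tables referenced in the queries already exist via the application's migrations, and the database construction arguments are placeholders rather than the exact signature used in the app:

    # Hypothetical wiring; SqliteDatabase construction details may differ in the app proper.
    db = SqliteDatabase(config, logger)                  # placeholder arguments
    store = ModelMetadataStoreSQL(db)

    store.add_metadata("abc123", fetched_metadata)       # key must already exist in model_config
    print(store.list_tags())
    print(store.search_by_tag({"anime", "style"}))       # keys of models carrying both tags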
@@ -6,19 +6,20 @@ Abstract base class for storing and retrieving model configuration records.
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
-from typing import List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union

from pydantic import BaseModel, Field

from invokeai.app.services.shared.pagination import PaginatedResults
-from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager import (
    AnyModelConfig,
    BaseModelType,
    ModelFormat,
    ModelType,
)
-from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
+from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
+
+from ..model_metadata import ModelMetadataStoreBase


class DuplicateModelException(Exception):

@@ -59,33 +60,11 @@ class ModelSummary(BaseModel):
    tags: Set[str] = Field(description="tags associated with model")


-class ModelRecordChanges(BaseModelExcludeNull):
-    """A set of changes to apply to a model."""
-
-    # Changes applicable to all models
-    name: Optional[str] = Field(description="Name of the model.", default=None)
-    path: Optional[str] = Field(description="Path to the model.", default=None)
-    description: Optional[str] = Field(description="Model description", default=None)
-    base: Optional[BaseModelType] = Field(description="The base model.", default=None)
-    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
-    default_settings: Optional[ModelDefaultSettings] = Field(
-        description="Default settings for this model", default=None
-    )
-
-    # Checkpoint-specific changes
-    # TODO(MM2): Should we expose these? Feels footgun-y...
-    variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
-    prediction_type: Optional[SchedulerPredictionType] = Field(
-        description="The prediction type of the model.", default=None
-    )
-    upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
-
-
class ModelRecordServiceBase(ABC):
    """Abstract base class for storage and retrieval of model configs."""

    @abstractmethod
-   def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
+   def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
        """
        Add a model to the database.

@@ -109,12 +88,13 @@ class ModelRecordServiceBase(ABC):
        pass

    @abstractmethod
-   def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
+   def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
        """
        Update the model, returning the updated version.

-       :param key: Unique key for the model to be updated.
-       :param changes: A set of changes to apply to this model. Changes are validated before being written.
+       :param key: Unique key for the model to be updated
+       :param config: Model configuration record. Either a dict with the
+           required fields, or a ModelConfigBase instance.
        """
        pass

@@ -129,6 +109,40 @@ class ModelRecordServiceBase(ABC):
        """
        pass

+   @property
+   @abstractmethod
+   def metadata_store(self) -> ModelMetadataStoreBase:
+       """Return a ModelMetadataStore initialized on the same database."""
+       pass
+
+   @abstractmethod
+   def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
+       """
+       Retrieve metadata (if any) from when model was downloaded from a repo.
+
+       :param key: Model key
+       """
+       pass
+
+   @abstractmethod
+   def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
+       """List metadata for all models that have it."""
+       pass
+
+   @abstractmethod
+   def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
+       """
+       Search model metadata for ones with all listed tags and return their corresponding configs.
+
+       :param tags: Set of tags to search for. All tags must be present.
+       """
+       pass
+
+   @abstractmethod
+   def list_tags(self) -> Set[str]:
+       """Return a unique set of all the model tags in the metadata database."""
+       pass
+
    @abstractmethod
    def list_models(
        self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default

@@ -203,3 +217,21 @@ class ModelRecordServiceBase(ABC):
                f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
            )
        return model_configs[0]
+
+   def rename_model(
+       self,
+       key: str,
+       new_name: str,
+   ) -> AnyModelConfig:
+       """
+       Rename the indicated model. Just a special case of update_model().
+
+       In some implementations, renaming the model may involve changing where
+       it is stored on the filesystem. So this is broken out.
+
+       :param key: Model key
+       :param new_name: New name for model
+       """
+       config = self.get_model(key)
+       config.name = new_name
+       return self.update_model(key, config)
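The net effect of these hunks is that, on the added side, the record service takes an explicit key plus a whole config object instead of a `ModelRecordChanges` patch, and gains metadata passthrough and a `rename_model()` convenience built on `update_model()`. A hedged sketch of the resulting call pattern; the key, description, and name are placeholders:

    # Hypothetical calls against a ModelRecordServiceBase implementation (`record_store`).
    config = record_store.get_model("abc123")            # AnyModelConfig
    config.description = "Updated description"
    record_store.update_model("abc123", config)          # pass the whole config back

    record_store.rename_model("abc123", "better-name")   # wrapper over update_model()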
@@ -43,7 +43,7 @@ import json
import sqlite3
from math import ceil
from pathlib import Path
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union

from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (

@@ -53,11 +53,12 @@ from invokeai.backend.model_manager.config import (
    ModelFormat,
    ModelType,
)
+from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
+
+from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
    DuplicateModelException,
-   ModelRecordChanges,
    ModelRecordOrderBy,
    ModelRecordServiceBase,
    ModelSummary,

@@ -68,7 +69,7 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
    """Implementation of the ModelConfigStore ABC using a SQL database."""

-   def __init__(self, db: SqliteDatabase):
+   def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
        """
        Initialize a new object from preexisting sqlite3 connection and threading lock objects.

@@ -77,13 +78,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        super().__init__()
        self._db = db
        self._cursor = db.conn.cursor()
+       self._metadata_store = metadata_store

    @property
    def db(self) -> SqliteDatabase:
        """Return the underlying database."""
        return self._db

-   def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
+   def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
        """
        Add a model to the database.

@@ -93,19 +95,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):

        Can raise DuplicateModelException and InvalidModelConfigException exceptions.
        """
+       record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config obect.
+       json_serialized = record.model_dump_json()  # and turn it into a json string.
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
-                   INSERT INTO models (
+                   INSERT INTO model_config (
                        id,
+                       original_hash,
                        config
                    )
-                   VALUES (?,?);
+                   VALUES (?,?,?);
                    """,
                    (
-                       config.key,
-                       config.model_dump_json(),
+                       key,
+                       record.original_hash,
+                       json_serialized,
                    ),
                )
                self._db.conn.commit()

@@ -113,12 +119,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
            except sqlite3.IntegrityError as e:
                self._db.conn.rollback()
                if "UNIQUE constraint failed" in str(e):
-                   if "models.path" in str(e):
-                       msg = f"A model with path '{config.path}' is already installed"
-                   elif "models.name" in str(e):
-                       msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
+                   if "model_config.path" in str(e):
+                       msg = f"A model with path '{record.path}' is already installed"
+                   elif "model_config.name" in str(e):
+                       msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
                    else:
-                       msg = f"A model with key '{config.key}' is already installed"
+                       msg = f"A model with key '{key}' is already installed"
                    raise DuplicateModelException(msg) from e
                else:
                    raise e

@@ -126,7 +132,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
                self._db.conn.rollback()
                raise e

-       return self.get_model(config.key)
+       return self.get_model(key)

    def del_model(self, key: str) -> None:
        """

@@ -140,7 +146,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
            try:
                self._cursor.execute(
                    """--sql
-                   DELETE FROM models
+                   DELETE FROM model_config
                    WHERE id=?;
                    """,
                    (key,),

@@ -152,20 +158,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
                self._db.conn.rollback()
                raise e

-   def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
-       record = self.get_model(key)
-
-       # Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
-       for field_name in changes.model_fields_set:
-           setattr(record, field_name, getattr(changes, field_name))
-
-       json_serialized = record.model_dump_json()
-
+   def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
+       """
+       Update the model, returning the updated version.
+
+       :param key: Unique key for the model to be updated
+       :param config: Model configuration record. Either a dict with the
+           required fields, or a ModelConfigBase instance.
+       """
+       record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config obect
+       json_serialized = record.model_dump_json()  # and turn it into a json string.
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
-                   UPDATE models
+                   UPDATE model_config
                    SET
                        config=?
                    WHERE id=?;

@@ -192,7 +199,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        with self._db.lock:
            self._cursor.execute(
                """--sql
-               SELECT config, strftime('%s',updated_at) FROM models
+               SELECT config, strftime('%s',updated_at) FROM model_config
                WHERE id=?;
                """,
                (key,),

@@ -213,7 +220,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        with self._db.lock:
            self._cursor.execute(
                """--sql
-               select count(*) FROM models
+               select count(*) FROM model_config
                WHERE id=?;
                """,
                (key,),

@@ -239,8 +246,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        If none of the optional filters are passed, will return all
        models in the database.
        """
-       where_clause: list[str] = []
-       bindings: list[str] = []
+       results = []
+       where_clause = []
+       bindings = []
        if model_name:
            where_clause.append("name=?")
            bindings.append(model_name)

@@ -257,13 +265,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        with self._db.lock:
            self._cursor.execute(
                f"""--sql
-               SELECT config, strftime('%s',updated_at) FROM models
+               select config, strftime('%s',updated_at) FROM model_config
                {where};
                """,
                tuple(bindings),
            )
-           result = self._cursor.fetchall()
-           results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
+           results = [
+               ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
+           ]
        return results

    def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:

@@ -272,7 +281,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        with self._db.lock:
            self._cursor.execute(
                """--sql
-               SELECT config, strftime('%s',updated_at) FROM models
+               SELECT config, strftime('%s',updated_at) FROM model_config
                WHERE path=?;
                """,
                (str(path),),

@@ -283,13 +292,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        return results

    def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
-       """Return models with the indicated hash."""
+       """Return models with the indicated original_hash."""
        results = []
        with self._db.lock:
            self._cursor.execute(
                """--sql
-               SELECT config, strftime('%s',updated_at) FROM models
-               WHERE hash=?;
+               SELECT config, strftime('%s',updated_at) FROM model_config
+               WHERE original_hash=?;
                """,
                (hash,),
            )

@@ -298,35 +307,83 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        ]
        return results

+   @property
+   def metadata_store(self) -> ModelMetadataStoreBase:
+       """Return a ModelMetadataStore initialized on the same database."""
+       return self._metadata_store
+
+   def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
+       """
+       Retrieve metadata (if any) from when model was downloaded from a repo.
+
+       :param key: Model key
+       """
+       store = self.metadata_store
+       try:
+           metadata = store.get_metadata(key)
+           return metadata
+       except UnknownMetadataException:
+           return None
+
+   def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
+       """
+       Search model metadata for ones with all listed tags and return their corresponding configs.
+
+       :param tags: Set of tags to search for. All tags must be present.
+       """
+       store = ModelMetadataStoreSQL(self._db)
+       keys = store.search_by_tag(tags)
+       return [self.get_model(x) for x in keys]
+
+   def list_tags(self) -> Set[str]:
+       """Return a unique set of all the model tags in the metadata database."""
+       store = ModelMetadataStoreSQL(self._db)
+       return store.list_tags()
+
+   def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
+       """List metadata for all models that have it."""
+       store = ModelMetadataStoreSQL(self._db)
+       return store.list_all_metadata()
+
    def list_models(
        self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
    ) -> PaginatedResults[ModelSummary]:
        """Return a paginated summary listing of each model in the database."""
-       assert isinstance(order_by, ModelRecordOrderBy)
        ordering = {
-           ModelRecordOrderBy.Default: "type, base, format, name",
-           ModelRecordOrderBy.Type: "type",
-           ModelRecordOrderBy.Base: "base",
-           ModelRecordOrderBy.Name: "name",
-           ModelRecordOrderBy.Format: "format",
+           ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
+           ModelRecordOrderBy.Type: "a.type",
+           ModelRecordOrderBy.Base: "a.base",
+           ModelRecordOrderBy.Name: "a.name",
+           ModelRecordOrderBy.Format: "a.format",
        }

+       def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
+           """Fix up results so that there are no null values."""
+           result: Dict[str, Union[str, int, Set[str]]] = {}
+           for key, item in summary.items():
+               result[key] = item or ""
+           result["tags"] = set(json.loads(summary["tags"] or "[]"))
+           return result
+
        # Lock so that the database isn't updated while we're doing the two queries.
        with self._db.lock:
            # query1: get the total number of model configs
            self._cursor.execute(
                """--sql
-               select count(*) from models;
+               select count(*) from model_config;
                """,
                (),
            )
            total = int(self._cursor.fetchone()[0])

-           # query2: fetch key fields
+           # query2: fetch key fields from the join of model_config and model_metadata
            self._cursor.execute(
                f"""--sql
-               SELECT config
-               FROM models
+               SELECT a.id as key, a.type, a.base, a.format, a.name,
+                      json_extract(a.config, '$.description') as description,
+                      json_extract(b.metadata, '$.tags') as tags
+               FROM model_config AS a
+               LEFT JOIN model_metadata AS b on a.id=b.id
                ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
                LIMIT ?
                OFFSET ?;

@@ -337,7 +394,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
                ),
            )
            rows = self._cursor.fetchall()
-           items = [ModelSummary.model_validate(dict(x)) for x in rows]
+           items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
            return PaginatedResults(
                page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
            )
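A small sketch of consuming the paginated listing shown above. Page numbering starts at 0 and the total page count comes back on the `PaginatedResults`, so a caller can loop until exhausted; `record_store` is a placeholder for a configured service instance:

    page = 0
    while True:
        results = record_store.list_models(page=page, per_page=25, order_by=ModelRecordOrderBy.Name)
        for summary in results.items:
            print(summary.name, summary.tags)   # tags come from model_metadata via the LEFT JOIN
        if page >= results.pages - 1:
            break
        page += 1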
@@ -1,35 +1,6 @@
from abc import ABC, abstractmethod
-from threading import Event
-
-from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
-from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
-
-
-class SessionRunnerBase(ABC):
-    """
-    Base class for session runner.
-    """
-
-    @abstractmethod
-    def start(self, services: InvocationServices, cancel_event: Event) -> None:
-        """Starts the session runner"""
-        pass
-
-    @abstractmethod
-    def run(self, queue_item: SessionQueueItem) -> None:
-        """Runs the session"""
-        pass
-
-    @abstractmethod
-    def complete(self, queue_item: SessionQueueItem) -> None:
-        """Completes the session"""
-        pass
-
-    @abstractmethod
-    def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
-        """Runs an already prepared node on the session"""
-        pass


class SessionProcessorBase(ABC):
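The `SessionRunnerBase` ABC removed here is the hook point the other side of the compare uses for per-node callbacks; its default implementation, `DefaultSessionRunner`, is removed in the next hunk. A hedged sketch of how those callbacks are meant to be supplied, based only on the constructor signature visible in the removed lines; the logging bodies are illustrative:

    # Hypothetical callbacks; signatures follow the removed DefaultSessionRunner.__init__.
    def log_before(invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool:
        print(f"running {invocation.id} for queue item {queue_item.item_id}")
        return True

    def log_after(invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool:
        print(f"finished {invocation.id}")
        return True

    runner = DefaultSessionRunner(on_before_run_node=log_before, on_after_run_node=log_after)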
@@ -2,14 +2,13 @@ import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
-from typing import Callable, Optional, Union
+from typing import Optional

from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent

from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
-from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem

@@ -17,164 +16,15 @@ from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.app.util.profiler import Profiler

from ..invoker import Invoker
-from .session_processor_base import SessionProcessorBase, SessionRunnerBase
+from .session_processor_base import SessionProcessorBase
from .session_processor_common import SessionProcessorStatus


-class DefaultSessionRunner(SessionRunnerBase):
-    """Processes a single session's invocations"""
-
-    def __init__(
-        self,
-        on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
-        on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
-    ):
-        self.on_before_run_node = on_before_run_node
-        self.on_after_run_node = on_after_run_node
-
-    def start(self, services: InvocationServices, cancel_event: ThreadEvent):
-        """Start the session runner"""
-        self.services = services
-        self.cancel_event = cancel_event
-
-    def run(self, queue_item: SessionQueueItem):
-        """Run the graph"""
-        if not queue_item.session:
-            raise ValueError("Queue item has no session")
-        # Loop over invocations until the session is complete or canceled
-        while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
-            # Prepare the next node
-            invocation = queue_item.session.next()
-            if invocation is None:
-                # If there are no more invocations, complete the graph
-                break
-            # Build invocation context (the node-facing API
-            self.run_node(invocation.id, queue_item)
-        self.complete(queue_item)
-
-    def complete(self, queue_item: SessionQueueItem):
-        """Complete the graph"""
-        self.services.events.emit_graph_execution_complete(
-            queue_batch_id=queue_item.batch_id,
-            queue_item_id=queue_item.item_id,
-            queue_id=queue_item.queue_id,
-            graph_execution_state_id=queue_item.session.id,
-        )
-
-    def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
-        """Run before a node is executed"""
-        # Send starting event
-        self.services.events.emit_invocation_started(
-            queue_batch_id=queue_item.batch_id,
-            queue_item_id=queue_item.item_id,
-            queue_id=queue_item.queue_id,
-            graph_execution_state_id=queue_item.session_id,
-            node=invocation.model_dump(),
-            source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
-        )
-        if self.on_before_run_node is not None:
-            self.on_before_run_node(invocation, queue_item)
-
-    def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
-        """Run after a node is executed"""
-        if self.on_after_run_node is not None:
-            self.on_after_run_node(invocation, queue_item)
-
-    def run_node(self, node_id: str, queue_item: SessionQueueItem):
-        """Run a single node in the graph"""
-        # If this error raises a NodeNotFoundError that's handled by the processor
-        invocation = queue_item.session.execution_graph.get_node(node_id)
-        try:
-            self._on_before_run_node(invocation, queue_item)
-            data = InvocationContextData(
-                invocation=invocation,
-                source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
-                queue_item=queue_item,
-            )
-
-            # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
-            with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
||||||
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
|
||||||
context = build_invocation_context(
|
|
||||||
data=data,
|
|
||||||
services=self.services,
|
|
||||||
cancel_event=self.cancel_event,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Invoke the node
|
|
||||||
outputs = invocation.invoke_internal(context=context, services=self.services)
|
|
||||||
|
|
||||||
# Save outputs and history
|
|
||||||
queue_item.session.complete(invocation.id, outputs)
|
|
||||||
|
|
||||||
self._on_after_run_node(invocation, queue_item)
|
|
||||||
# Send complete event on successful runs
|
|
||||||
self.services.events.emit_invocation_complete(
|
|
||||||
queue_batch_id=queue_item.batch_id,
|
|
||||||
queue_item_id=queue_item.item_id,
|
|
||||||
queue_id=queue_item.queue_id,
|
|
||||||
graph_execution_state_id=queue_item.session.id,
|
|
||||||
node=invocation.model_dump(),
|
|
||||||
source_node_id=data.source_invocation_id,
|
|
||||||
result=outputs.model_dump(),
|
|
||||||
)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
# TODO(MM2): Create an event for this
|
|
||||||
pass
|
|
||||||
except CanceledException:
|
|
||||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
|
||||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
|
||||||
# be able to cancel them mid-execution.
|
|
||||||
#
|
|
||||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
|
||||||
# is executed after each step. This step callback checks if the canceled event is set,
|
|
||||||
# then raises a CanceledException to stop execution immediately.
|
|
||||||
#
|
|
||||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
|
||||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
error = traceback.format_exc()
|
|
||||||
|
|
||||||
# Save error
|
|
||||||
queue_item.session.set_node_error(invocation.id, error)
|
|
||||||
self.services.logger.error(
|
|
||||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
|
|
||||||
)
|
|
||||||
self.services.logger.error(error)
|
|
||||||
|
|
||||||
# Send error event
|
|
||||||
self.services.events.emit_invocation_error(
|
|
||||||
queue_batch_id=queue_item.session_id,
|
|
||||||
queue_item_id=queue_item.item_id,
|
|
||||||
queue_id=queue_item.queue_id,
|
|
||||||
graph_execution_state_id=queue_item.session.id,
|
|
||||||
node=invocation.model_dump(),
|
|
||||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
|
||||||
error_type=e.__class__.__name__,
|
|
||||||
error=error,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultSessionProcessor(SessionProcessorBase):
|
class DefaultSessionProcessor(SessionProcessorBase):
|
||||||
"""Processes sessions from the session queue"""
|
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
|
||||||
|
|
||||||
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
|
||||||
|
|
||||||
def start(
|
|
||||||
self,
|
|
||||||
invoker: Invoker,
|
|
||||||
thread_limit: int = 1,
|
|
||||||
polling_interval: int = 1,
|
|
||||||
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
|
||||||
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
|
||||||
) -> None:
|
|
||||||
self._invoker: Invoker = invoker
|
self._invoker: Invoker = invoker
|
||||||
self._queue_item: Optional[SessionQueueItem] = None
|
self._queue_item: Optional[SessionQueueItem] = None
|
||||||
self._invocation: Optional[BaseInvocation] = None
|
self._invocation: Optional[BaseInvocation] = None
|
||||||
self.on_before_run_session = on_before_run_session
|
|
||||||
self.on_after_run_session = on_after_run_session
|
|
||||||
|
|
||||||
self._resume_event = ThreadEvent()
|
self._resume_event = ThreadEvent()
|
||||||
self._stop_event = ThreadEvent()
|
self._stop_event = ThreadEvent()
|
||||||
@@ -209,7 +59,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
"cancel_event": self._cancel_event,
|
"cancel_event": self._cancel_event,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
|
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
def stop(self, *args, **kwargs) -> None:
|
def stop(self, *args, **kwargs) -> None:
|
||||||
@@ -268,34 +117,131 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||||
cancel_event.clear()
|
cancel_event.clear()
|
||||||
|
|
||||||
# If we have a on_before_run_session callback, call it
|
|
||||||
if self.on_before_run_session is not None:
|
|
||||||
self.on_before_run_session(self._queue_item)
|
|
||||||
|
|
||||||
# If profiling is enabled, start the profiler
|
# If profiling is enabled, start the profiler
|
||||||
if self._profiler is not None:
|
if self._profiler is not None:
|
||||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||||
|
|
||||||
# Run the graph
|
# Prepare invocations and take the first
|
||||||
self.session_runner.run(queue_item=self._queue_item)
|
self._invocation = self._queue_item.session.next()
|
||||||
|
|
||||||
# If we are profiling, stop the profiler and dump the profile & stats
|
# Loop over invocations until the session is complete or canceled
|
||||||
if self._profiler:
|
while self._invocation is not None and not cancel_event.is_set():
|
||||||
profile_path = self._profiler.stop()
|
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||||
stats_path = profile_path.with_suffix(".json")
|
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||||
self._invoker.services.performance_statistics.dump_stats(
|
|
||||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
# Send starting event
|
||||||
|
self._invoker.services.events.emit_invocation_started(
|
||||||
|
queue_batch_id=self._queue_item.batch_id,
|
||||||
|
queue_item_id=self._queue_item.item_id,
|
||||||
|
queue_id=self._queue_item.queue_id,
|
||||||
|
graph_execution_state_id=self._queue_item.session_id,
|
||||||
|
node=self._invocation.model_dump(),
|
||||||
|
source_node_id=source_invocation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||||
# we don't care about that - suppress the error.
|
try:
|
||||||
with suppress(GESStatsNotFoundError):
|
with self._invoker.services.performance_statistics.collect_stats(
|
||||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
self._invocation, self._queue_item.session.id
|
||||||
self._invoker.services.performance_statistics.reset_stats()
|
):
|
||||||
|
# Build invocation context (the node-facing API)
|
||||||
|
data = InvocationContextData(
|
||||||
|
invocation=self._invocation,
|
||||||
|
source_invocation_id=source_invocation_id,
|
||||||
|
queue_item=self._queue_item,
|
||||||
|
)
|
||||||
|
context = build_invocation_context(
|
||||||
|
data=data,
|
||||||
|
services=self._invoker.services,
|
||||||
|
cancel_event=self._cancel_event,
|
||||||
|
)
|
||||||
|
|
||||||
# If we have a on_after_run_session callback, call it
|
# Invoke the node
|
||||||
if self.on_after_run_session is not None:
|
outputs = self._invocation.invoke_internal(
|
||||||
self.on_after_run_session(self._queue_item)
|
context=context, services=self._invoker.services
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save outputs and history
|
||||||
|
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||||
|
|
||||||
|
# Send complete event
|
||||||
|
self._invoker.services.events.emit_invocation_complete(
|
||||||
|
queue_batch_id=self._queue_item.batch_id,
|
||||||
|
queue_item_id=self._queue_item.item_id,
|
||||||
|
queue_id=self._queue_item.queue_id,
|
||||||
|
graph_execution_state_id=self._queue_item.session.id,
|
||||||
|
node=self._invocation.model_dump(),
|
||||||
|
source_node_id=source_invocation_id,
|
||||||
|
result=outputs.model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# TODO(MM2): Create an event for this
|
||||||
|
pass
|
||||||
|
|
||||||
|
except CanceledException:
|
||||||
|
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||||
|
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||||
|
# be able to cancel them mid-execution.
|
||||||
|
#
|
||||||
|
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||||
|
# is executed after each step. This step callback checks if the canceled event is set,
|
||||||
|
# then raises a CanceledException to stop execution immediately.
|
||||||
|
#
|
||||||
|
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||||
|
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||||
|
pass
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error = traceback.format_exc()
|
||||||
|
|
||||||
|
# Save error
|
||||||
|
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||||
|
self._invoker.services.logger.error(
|
||||||
|
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||||
|
)
|
||||||
|
self._invoker.services.logger.error(error)
|
||||||
|
|
||||||
|
# Send error event
|
||||||
|
self._invoker.services.events.emit_invocation_error(
|
||||||
|
queue_batch_id=self._queue_item.session_id,
|
||||||
|
queue_item_id=self._queue_item.item_id,
|
||||||
|
queue_id=self._queue_item.queue_id,
|
||||||
|
graph_execution_state_id=self._queue_item.session.id,
|
||||||
|
node=self._invocation.model_dump(),
|
||||||
|
source_node_id=source_invocation_id,
|
||||||
|
error_type=e.__class__.__name__,
|
||||||
|
error=error,
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
# The session is complete if the all invocations are complete or there was an error
|
||||||
|
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||||
|
# Send complete event
|
||||||
|
self._invoker.services.events.emit_graph_execution_complete(
|
||||||
|
queue_batch_id=self._queue_item.batch_id,
|
||||||
|
queue_item_id=self._queue_item.item_id,
|
||||||
|
queue_id=self._queue_item.queue_id,
|
||||||
|
graph_execution_state_id=self._queue_item.session.id,
|
||||||
|
)
|
||||||
|
# If we are profiling, stop the profiler and dump the profile & stats
|
||||||
|
if self._profiler:
|
||||||
|
profile_path = self._profiler.stop()
|
||||||
|
stats_path = profile_path.with_suffix(".json")
|
||||||
|
self._invoker.services.performance_statistics.dump_stats(
|
||||||
|
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||||
|
)
|
||||||
|
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||||
|
# we don't care about that - suppress the error.
|
||||||
|
with suppress(GESStatsNotFoundError):
|
||||||
|
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||||
|
self._invoker.services.performance_statistics.reset_stats()
|
||||||
|
|
||||||
|
# Set the invocation to None to prepare for the next session
|
||||||
|
self._invocation = None
|
||||||
|
else:
|
||||||
|
# Prepare the next invocation
|
||||||
|
self._invocation = self._queue_item.session.next()
|
||||||
|
|
||||||
# The session is complete, immediately poll for next session
|
# The session is complete, immediately poll for next session
|
||||||
self._queue_item = None
|
self._queue_item = None
|
||||||
@@ -329,4 +275,3 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
poll_now_event.clear()
|
poll_now_event.clear()
|
||||||
self._queue_item = None
|
self._queue_item = None
|
||||||
self._thread_semaphore.release()
|
self._thread_semaphore.release()
|
||||||
self._invoker.services.logger.debug("Session processor stopped")
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
|
|||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
|
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||||
|
|
||||||
|
|
||||||
@@ -36,7 +35,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
|||||||
migrator.register_migration(build_migration_4())
|
migrator.register_migration(build_migration_4())
|
||||||
migrator.register_migration(build_migration_5())
|
migrator.register_migration(build_migration_5())
|
||||||
migrator.register_migration(build_migration_6())
|
migrator.register_migration(build_migration_6())
|
||||||
migrator.register_migration(build_migration_7())
|
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
|
||||||
return db
|
return db
|
||||||
|
|||||||
@@ -1,88 +0,0 @@
|
|||||||
import sqlite3
|
|
||||||
|
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
|
||||||
|
|
||||||
|
|
||||||
class Migration7Callback:
|
|
||||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
|
||||||
self._create_models_table(cursor)
|
|
||||||
self._drop_old_models_tables(cursor)
|
|
||||||
|
|
||||||
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
|
|
||||||
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
|
|
||||||
|
|
||||||
tables = ["model_records", "model_metadata", "model_tags", "tags"]
|
|
||||||
|
|
||||||
for table in tables:
|
|
||||||
cursor.execute(f"DROP TABLE IF EXISTS {table};")
|
|
||||||
|
|
||||||
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
|
|
||||||
"""Creates the v4.0.0 models table."""
|
|
||||||
|
|
||||||
tables = [
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS models (
|
|
||||||
id TEXT NOT NULL PRIMARY KEY,
|
|
||||||
hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
|
|
||||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
|
||||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
|
||||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
|
||||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
|
||||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
|
||||||
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
|
|
||||||
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
|
|
||||||
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
|
|
||||||
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
|
|
||||||
trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
|
|
||||||
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
|
|
||||||
config TEXT NOT NULL,
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Updated via trigger
|
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- unique constraint on combo of name, base and type
|
|
||||||
UNIQUE(name, base, type)
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add trigger for `updated_at`.
|
|
||||||
triggers = [
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS models_updated_at
|
|
||||||
AFTER UPDATE
|
|
||||||
ON models FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
|
||||||
WHERE id = old.id;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add indexes for searchable fields
|
|
||||||
indices = [
|
|
||||||
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
|
|
||||||
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
|
|
||||||
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
|
|
||||||
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
|
|
||||||
]
|
|
||||||
|
|
||||||
for stmt in tables + indices + triggers:
|
|
||||||
cursor.execute(stmt)
|
|
||||||
|
|
||||||
|
|
||||||
def build_migration_7() -> Migration:
|
|
||||||
"""
|
|
||||||
Build the migration from database version 6 to 7.
|
|
||||||
|
|
||||||
This migration does the following:
|
|
||||||
- Adds the new models table
|
|
||||||
- Drops the old model_records, model_metadata, model_tags and tags tables.
|
|
||||||
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
|
|
||||||
"""
|
|
||||||
migration_7 = Migration(
|
|
||||||
from_version=6,
|
|
||||||
to_version=7,
|
|
||||||
callback=Migration7Callback(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return migration_7
|
|
||||||
@@ -150,7 +150,7 @@ class MigrateModelYamlToDb1:
|
|||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
key,
|
key,
|
||||||
record.hash,
|
record.original_hash,
|
||||||
json_serialized,
|
json_serialized,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,8 +17,7 @@ class MigrateCallback(Protocol):
|
|||||||
See :class:`Migration` for an example.
|
See :class:`Migration` for an example.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class MigrationError(RuntimeError):
|
class MigrationError(RuntimeError):
|
||||||
|
|||||||
55
invokeai/app/util/metadata.py
Normal file
55
invokeai/app/util/metadata.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.graph import Edge
|
||||||
|
|
||||||
|
|
||||||
|
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Parses raw session string, returning a dict of the graph.
|
||||||
|
|
||||||
|
Only the general graph shape is validated; none of the fields are validated.
|
||||||
|
|
||||||
|
Any `metadata_accumulator` nodes and edges are removed.
|
||||||
|
|
||||||
|
Any validation failure will return None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph = json.loads(session_raw).get("graph", None)
|
||||||
|
|
||||||
|
# sanity check make sure the graph is at least reasonably shaped
|
||||||
|
if (
|
||||||
|
not isinstance(graph, dict)
|
||||||
|
or "nodes" not in graph
|
||||||
|
or not isinstance(graph["nodes"], dict)
|
||||||
|
or "edges" not in graph
|
||||||
|
or not isinstance(graph["edges"], list)
|
||||||
|
):
|
||||||
|
# something has gone terribly awry, return an empty dict
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# delete the `metadata_accumulator` node
|
||||||
|
del graph["nodes"]["metadata_accumulator"]
|
||||||
|
except KeyError:
|
||||||
|
# no accumulator node, all good
|
||||||
|
pass
|
||||||
|
|
||||||
|
# delete any edges to or from it
|
||||||
|
for i, edge in enumerate(graph["edges"]):
|
||||||
|
try:
|
||||||
|
# try to parse the edge
|
||||||
|
Edge(**edge)
|
||||||
|
except ValidationError:
|
||||||
|
# something has gone terribly awry, return an empty dict
|
||||||
|
return None
|
||||||
|
|
||||||
|
if (
|
||||||
|
edge["source"]["node_id"] == "metadata_accumulator"
|
||||||
|
or edge["destination"]["node_id"] == "metadata_accumulator"
|
||||||
|
):
|
||||||
|
del graph["edges"][i]
|
||||||
|
|
||||||
|
return graph
|
||||||
@@ -1,182 +0,0 @@
|
|||||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
|
||||||
# and modified as needed
|
|
||||||
|
|
||||||
# tencent-ailab comment:
|
|
||||||
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
|
|
||||||
|
|
||||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
|
|
||||||
|
|
||||||
|
|
||||||
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
|
|
||||||
# loading.
|
|
||||||
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
DiffusersAttnProcessor2_0.__init__(self)
|
|
||||||
nn.Module.__init__(self)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn,
|
|
||||||
hidden_states,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
attention_mask=None,
|
|
||||||
temb=None,
|
|
||||||
ip_adapter_image_prompt_embeds=None,
|
|
||||||
):
|
|
||||||
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
|
|
||||||
ip_adapter_image_prompt_embeds parameter.
|
|
||||||
"""
|
|
||||||
return DiffusersAttnProcessor2_0.__call__(
|
|
||||||
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class IPAttnProcessor2_0(torch.nn.Module):
|
|
||||||
r"""
|
|
||||||
Attention processor for IP-Adapater for PyTorch 2.0.
|
|
||||||
Args:
|
|
||||||
hidden_size (`int`):
|
|
||||||
The hidden size of the attention layer.
|
|
||||||
cross_attention_dim (`int`):
|
|
||||||
The number of channels in the `encoder_hidden_states`.
|
|
||||||
scale (`float`, defaults to 1.0):
|
|
||||||
the weight scale of image prompt.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if not hasattr(F, "scaled_dot_product_attention"):
|
|
||||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
|
||||||
|
|
||||||
assert len(weights) == len(scales)
|
|
||||||
|
|
||||||
self._weights = weights
|
|
||||||
self._scales = scales
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn,
|
|
||||||
hidden_states,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
attention_mask=None,
|
|
||||||
temb=None,
|
|
||||||
ip_adapter_image_prompt_embeds=None,
|
|
||||||
):
|
|
||||||
"""Apply IP-Adapter attention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
|
|
||||||
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
|
||||||
"""
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
if attn.spatial_norm is not None:
|
|
||||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
|
||||||
|
|
||||||
input_ndim = hidden_states.ndim
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
batch_size, channel, height, width = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
|
||||||
|
|
||||||
batch_size, sequence_length, _ = (
|
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
|
||||||
# scaled_dot_product_attention expects attention_mask shape to be
|
|
||||||
# (batch, heads, source_length, target_length)
|
|
||||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
|
||||||
|
|
||||||
if attn.group_norm is not None:
|
|
||||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
|
||||||
|
|
||||||
query = attn.to_q(hidden_states)
|
|
||||||
|
|
||||||
if encoder_hidden_states is None:
|
|
||||||
encoder_hidden_states = hidden_states
|
|
||||||
elif attn.norm_cross:
|
|
||||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
|
||||||
|
|
||||||
key = attn.to_k(encoder_hidden_states)
|
|
||||||
value = attn.to_v(encoder_hidden_states)
|
|
||||||
|
|
||||||
inner_dim = key.shape[-1]
|
|
||||||
head_dim = inner_dim // attn.heads
|
|
||||||
|
|
||||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
|
||||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
|
||||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
||||||
hidden_states = hidden_states.to(query.dtype)
|
|
||||||
|
|
||||||
if encoder_hidden_states is not None:
|
|
||||||
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
|
||||||
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
|
||||||
assert ip_adapter_image_prompt_embeds is not None
|
|
||||||
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
|
|
||||||
|
|
||||||
for ipa_embed, ipa_weights, scale in zip(
|
|
||||||
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
|
|
||||||
):
|
|
||||||
# The batch dimensions should match.
|
|
||||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
|
||||||
# The token_len dimensions should match.
|
|
||||||
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
|
||||||
|
|
||||||
ip_hidden_states = ipa_embed
|
|
||||||
|
|
||||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
|
||||||
|
|
||||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
|
||||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
|
||||||
|
|
||||||
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
|
||||||
|
|
||||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
|
||||||
|
|
||||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
|
||||||
ip_hidden_states = F.scaled_dot_product_attention(
|
|
||||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
|
||||||
|
|
||||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
||||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
|
||||||
|
|
||||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + scale * ip_hidden_states
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
|
||||||
|
|
||||||
if attn.residual_connection:
|
|
||||||
hidden_states = hidden_states + residual
|
|
||||||
|
|
||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
@@ -25,13 +25,10 @@ from enum import Enum
|
|||||||
from typing import Literal, Optional, Type, Union
|
from typing import Literal, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.modeling_utils import ModelMixin
|
from diffusers import ModelMixin
|
||||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||||
from typing_extensions import Annotated, Any, Dict
|
from typing_extensions import Annotated, Any, Dict
|
||||||
|
|
||||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
|
||||||
from invokeai.app.util.misc import uuid_string
|
|
||||||
|
|
||||||
from ..raw_model import RawModel
|
from ..raw_model import RawModel
|
||||||
|
|
||||||
# ModelMixin is the base class for all diffusers and transformers models
|
# ModelMixin is the base class for all diffusers and transformers models
|
||||||
@@ -59,8 +56,8 @@ class ModelType(str, Enum):
|
|||||||
|
|
||||||
ONNX = "onnx"
|
ONNX = "onnx"
|
||||||
Main = "main"
|
Main = "main"
|
||||||
VAE = "vae"
|
Vae = "vae"
|
||||||
LoRA = "lora"
|
Lora = "lora"
|
||||||
ControlNet = "controlnet" # used by model_probe
|
ControlNet = "controlnet" # used by model_probe
|
||||||
TextualInversion = "embedding"
|
TextualInversion = "embedding"
|
||||||
IPAdapter = "ip_adapter"
|
IPAdapter = "ip_adapter"
|
||||||
@@ -76,9 +73,9 @@ class SubModelType(str, Enum):
|
|||||||
TextEncoder2 = "text_encoder_2"
|
TextEncoder2 = "text_encoder_2"
|
||||||
Tokenizer = "tokenizer"
|
Tokenizer = "tokenizer"
|
||||||
Tokenizer2 = "tokenizer_2"
|
Tokenizer2 = "tokenizer_2"
|
||||||
VAE = "vae"
|
Vae = "vae"
|
||||||
VAEDecoder = "vae_decoder"
|
VaeDecoder = "vae_decoder"
|
||||||
VAEEncoder = "vae_encoder"
|
VaeEncoder = "vae_encoder"
|
||||||
Scheduler = "scheduler"
|
Scheduler = "scheduler"
|
||||||
SafetyChecker = "safety_checker"
|
SafetyChecker = "safety_checker"
|
||||||
|
|
||||||
@@ -96,8 +93,8 @@ class ModelFormat(str, Enum):
|
|||||||
|
|
||||||
Diffusers = "diffusers"
|
Diffusers = "diffusers"
|
||||||
Checkpoint = "checkpoint"
|
Checkpoint = "checkpoint"
|
||||||
LyCORIS = "lycoris"
|
Lycoris = "lycoris"
|
||||||
ONNX = "onnx"
|
Onnx = "onnx"
|
||||||
Olive = "olive"
|
Olive = "olive"
|
||||||
EmbeddingFile = "embedding_file"
|
EmbeddingFile = "embedding_file"
|
||||||
EmbeddingFolder = "embedding_folder"
|
EmbeddingFolder = "embedding_folder"
|
||||||
@@ -115,186 +112,127 @@ class SchedulerPredictionType(str, Enum):
|
|||||||
class ModelRepoVariant(str, Enum):
|
class ModelRepoVariant(str, Enum):
|
||||||
"""Various hugging face variants on the diffusers format."""
|
"""Various hugging face variants on the diffusers format."""
|
||||||
|
|
||||||
Default = "" # model files without "fp16" or other qualifier - empty str
|
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
|
||||||
FP16 = "fp16"
|
FP16 = "fp16"
|
||||||
FP32 = "fp32"
|
FP32 = "fp32"
|
||||||
ONNX = "onnx"
|
ONNX = "onnx"
|
||||||
OpenVINO = "openvino"
|
OPENVINO = "openvino"
|
||||||
Flax = "flax"
|
FLAX = "flax"
|
||||||
|
|
||||||
|
|
||||||
class ModelSourceType(str, Enum):
|
|
||||||
"""Model source type."""
|
|
||||||
|
|
||||||
Path = "path"
|
|
||||||
Url = "url"
|
|
||||||
HFRepoID = "hf_repo_id"
|
|
||||||
CivitAI = "civitai"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDefaultSettings(BaseModel):
|
|
||||||
vae: str | None
|
|
||||||
vae_precision: str | None
|
|
||||||
scheduler: SCHEDULER_NAME_VALUES | None
|
|
||||||
steps: int | None
|
|
||||||
cfg_scale: float | None
|
|
||||||
cfg_rescale_multiplier: float | None
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigBase(BaseModel):
|
class ModelConfigBase(BaseModel):
|
||||||
"""Base class for model configuration information."""
|
"""Base class for model configuration information."""
|
||||||
|
|
||||||
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
|
path: str = Field(description="filesystem path to the model file or directory")
|
||||||
hash: str = Field(description="The hash of the model file(s).")
|
name: str = Field(description="model name")
|
||||||
path: str = Field(
|
base: BaseModelType = Field(description="base model")
|
||||||
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
|
type: ModelType = Field(description="type of the model")
|
||||||
)
|
format: ModelFormat = Field(description="model format")
|
||||||
name: str = Field(description="Name of the model.")
|
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||||
base: BaseModelType = Field(description="The base model.")
|
original_hash: Optional[str] = Field(
|
||||||
description: Optional[str] = Field(description="Model description", default=None)
|
description="original fasthash of model contents", default=None
|
||||||
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
) # this is assigned at install time and will not change
|
||||||
source_type: ModelSourceType = Field(description="The type of source")
|
current_hash: Optional[str] = Field(
|
||||||
source_api_response: Optional[str] = Field(
|
description="current fasthash of model contents", default=None
|
||||||
description="The original API response from the source, as stringified JSON.", default=None
|
) # if model is converted or otherwise modified, this will hold updated hash
|
||||||
)
|
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
||||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
|
||||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
|
||||||
description="Default settings for this model", default=None
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
schema["required"].extend(["key", "type", "format"])
|
schema["required"].extend(
|
||||||
|
["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
|
||||||
|
)
|
||||||
|
|
||||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
model_config = ConfigDict(
|
||||||
|
use_enum_values=False,
|
||||||
|
validate_assignment=True,
|
||||||
|
json_schema_extra=json_schema_extra,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update(self, attributes: Dict[str, Any]) -> None:
|
||||||
|
"""Update the object with fields in dict."""
|
||||||
|
for key, value in attributes.items():
|
||||||
|
setattr(self, key, value) # may raise a validation error
|
||||||
|
|
||||||
|
|
||||||
class CheckpointConfigBase(ModelConfigBase):
|
class _CheckpointConfig(ModelConfigBase):
|
||||||
"""Model config for checkpoint-style models."""
|
"""Model config for checkpoint-style models."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
config_path: str = Field(description="path to the checkpoint model config file")
|
config: str = Field(description="path to the checkpoint model config file")
|
||||||
converted_at: Optional[float] = Field(
|
|
||||||
description="When this model was last converted to diffusers", default_factory=time.time
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DiffusersConfigBase(ModelConfigBase):
|
class _DiffusersConfig(ModelConfigBase):
|
||||||
"""Model config for diffusers-style models."""
|
"""Model config for diffusers-style models."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||||
|
|
||||||
|
|
||||||
class LoRALyCORISConfig(ModelConfigBase):
|
class LoRAConfig(ModelConfigBase):
|
||||||
"""Model config for LoRA/Lycoris models."""
|
"""Model config for LoRA/Lycoris models."""
|
||||||
|
|
||||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
|
||||||
|
|
||||||
|
|
||||||
class LoRADiffusersConfig(ModelConfigBase):
|
class VaeCheckpointConfig(ModelConfigBase):
|
||||||
"""Model config for LoRA/Diffusers models."""
|
|
||||||
|
|
||||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
|
|
||||||
|
|
||||||
|
|
||||||
class VAECheckpointConfig(CheckpointConfigBase):
|
|
||||||
"""Model config for standalone VAE models."""
|
"""Model config for standalone VAE models."""
|
||||||
|
|
||||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
|
|
||||||
|
|
||||||
|
class VaeDiffusersConfig(ModelConfigBase):
|
||||||
class VAEDiffusersConfig(ModelConfigBase):
|
|
||||||
"""Model config for standalone VAE models (diffusers version)."""
|
"""Model config for standalone VAE models (diffusers version)."""
|
||||||
|
|
||||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
|
||||||
|
|
||||||
|
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||||
class ControlNetDiffusersConfig(DiffusersConfigBase):
|
|
||||||
"""Model config for ControlNet models (diffusers version)."""
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
|
|
||||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
|
|
||||||
|
|
||||||
|
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||||
class ControlNetCheckpointConfig(CheckpointConfigBase):
|
|
||||||
"""Model config for ControlNet models (diffusers version)."""
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
|
|
||||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
|
|
||||||
|
|
||||||
|
class TextualInversionConfig(ModelConfigBase):
|
||||||
class TextualInversionFileConfig(ModelConfigBase):
|
|
||||||
"""Model config for textual inversion embeddings."""
|
"""Model config for textual inversion embeddings."""
|
||||||
|
|
||||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||||
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionFolderConfig(ModelConfigBase):
|
class _MainConfig(ModelConfigBase):
|
||||||
"""Model config for textual inversion embeddings."""
|
"""Model config for main models."""
|
||||||
|
|
||||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
vae: Optional[str] = Field(default=None)
|
||||||
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
|
||||||
|
|
||||||
|
|
||||||
class MainCheckpointConfig(CheckpointConfigBase):
|
|
||||||
"""Model config for main checkpoint models."""
|
|
||||||
|
|
||||||
type: Literal[ModelType.Main] = ModelType.Main
|
|
||||||
variant: ModelVariantType = ModelVariantType.Normal
|
variant: ModelVariantType = ModelVariantType.Normal
|
||||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||||
upcast_attention: bool = False
|
upcast_attention: bool = False
|
||||||
|
ztsnr_training: bool = False
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
|
||||||
|
|
||||||
|
|
||||||
class MainDiffusersConfig(DiffusersConfigBase):
|
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||||
"""Model config for main diffusers models."""
|
"""Model config for main checkpoint models."""
|
||||||
|
|
||||||
type: Literal[ModelType.Main] = ModelType.Main
|
type: Literal[ModelType.Main] = ModelType.Main
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
|
"""Model config for main diffusers models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.Main] = ModelType.Main
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterConfig(ModelConfigBase):
|
class IPAdapterConfig(ModelConfigBase):
|
||||||
@@ -304,10 +242,6 @@ class IPAdapterConfig(ModelConfigBase):
|
|||||||
image_encoder_model_id: str
|
image_encoder_model_id: str
|
||||||
format: Literal[ModelFormat.InvokeAI]
|
format: Literal[ModelFormat.InvokeAI]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||||
"""Model config for ClipVision."""
|
"""Model config for ClipVision."""
|
||||||
@@ -315,65 +249,58 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
|
|||||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||||
format: Literal[ModelFormat.Diffusers]
|
format: Literal[ModelFormat.Diffusers]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
|
|
||||||
|
|
||||||
|
class T2IConfig(ModelConfigBase):
|
||||||
class T2IAdapterConfig(ModelConfigBase):
|
|
||||||
"""Model config for T2I."""
|
"""Model config for T2I."""
|
||||||
|
|
||||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||||
format: Literal[ModelFormat.Diffusers]
|
format: Literal[ModelFormat.Diffusers]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_tag() -> Tag:
|
|
||||||
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
|
||||||
|
|
||||||
|
_ControlNetConfig = Annotated[
|
||||||
|
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
||||||
|
Field(discriminator="format"),
|
||||||
|
]
|
||||||
|
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
||||||
|
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
||||||
|
|
||||||
def get_model_discriminator_value(v: Any) -> str:
|
AnyModelConfig = Union[
|
||||||
"""
|
_MainModelConfig,
|
||||||
Computes the discriminator value for a model config.
|
_VaeConfig,
|
||||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
_ControlNetConfig,
|
||||||
"""
|
# ModelConfigBase,
|
||||||
format_ = None
|
LoRAConfig,
|
||||||
type_ = None
|
TextualInversionConfig,
|
||||||
if isinstance(v, dict):
|
IPAdapterConfig,
|
||||||
format_ = v.get("format")
|
CLIPVisionDiffusersConfig,
|
||||||
if isinstance(format_, Enum):
|
T2IConfig,
|
||||||
format_ = format_.value
|
|
||||||
type_ = v.get("type")
|
|
||||||
if isinstance(type_, Enum):
|
|
||||||
type_ = type_.value
|
|
||||||
else:
|
|
||||||
format_ = v.format.value
|
|
||||||
type_ = v.type.value
|
|
||||||
v = f"{type_}.{format_}"
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
AnyModelConfig = Annotated[
|
|
||||||
Union[
|
|
||||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
|
||||||
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
|
||||||
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
|
||||||
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
|
||||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
|
||||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
|
||||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
|
||||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
|
||||||
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
|
||||||
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
|
||||||
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
|
|
||||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
|
||||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
|
||||||
],
|
|
||||||
Discriminator(get_model_discriminator_value),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
# IMPLEMENTATION NOTE:
|
||||||
|
# The preferred alternative to the above is a discriminated Union as shown
|
||||||
|
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
|
||||||
|
# This is a known issue. Please see:
|
||||||
|
# https://github.com/tiangolo/fastapi/discussions/9761 and
|
||||||
|
# https://github.com/tiangolo/fastapi/discussions/9287
|
||||||
|
# AnyModelConfig = Annotated[
|
||||||
|
# Union[
|
||||||
|
# _MainModelConfig,
|
||||||
|
# _ONNXConfig,
|
||||||
|
# _VaeConfig,
|
||||||
|
# _ControlNetConfig,
|
||||||
|
# LoRAConfig,
|
||||||
|
# TextualInversionConfig,
|
||||||
|
# IPAdapterConfig,
|
||||||
|
# CLIPVisionDiffusersConfig,
|
||||||
|
# T2IConfig,
|
||||||
|
# ],
|
||||||
|
# Field(discriminator="type"),
|
||||||
|
# ]
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigFactory(object):
|
class ModelConfigFactory(object):
|
||||||
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||||
|
|
||||||
@@ -405,6 +332,6 @@ class ModelConfigFactory(object):
|
|||||||
assert model is not None
|
assert model is not None
|
||||||
if key:
|
if key:
|
||||||
model.key = key
|
model.key = key
|
||||||
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
|
if timestamp:
|
||||||
model.converted_at = timestamp
|
model.last_modified = timestamp
|
||||||
return model # type: ignore
|
return model # type: ignore
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelRepoVariant,
|
ModelRepoVariant,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
|
||||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||||
@@ -51,7 +50,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
:param submodel_type: an ModelType enum indicating the portion of
|
:param submodel_type: an ModelType enum indicating the portion of
|
||||||
the model to retrieve (e.g. ModelType.Vae)
|
the model to retrieve (e.g. ModelType.Vae)
|
||||||
"""
|
"""
|
||||||
if model_config.type is ModelType.Main and not submodel_type:
|
if model_config.type == "main" and not submodel_type:
|
||||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||||
|
|
||||||
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
||||||
@@ -81,7 +80,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||||
return self._convert_model(config, model_path, cache_path)
|
return self._convert_model(config, model_path, cache_path)
|
||||||
|
|
||||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _load_if_needed(
|
def _load_if_needed(
|
||||||
@@ -120,7 +119,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
return calc_model_size_by_fs(
|
return calc_model_size_by_fs(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
subfolder=submodel_type.value if submodel_type else None,
|
subfolder=submodel_type.value if submodel_type else None,
|
||||||
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
|
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This needs to be implemented in subclasses that handle checkpoints
|
# This needs to be implemented in subclasses that handle checkpoints
|
||||||
|
|||||||
@@ -15,8 +15,10 @@ Use like this:

 """

+import hashlib
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
+from pathlib import Path
+from typing import Callable, Dict, Optional, Tuple, Type

 from ..config import (
 AnyModelConfig,
@@ -25,6 +27,8 @@ from ..config import (
 ModelFormat,
 ModelType,
 SubModelType,
+VaeCheckpointConfig,
+VaeDiffusersConfig,
 )
 from . import ModelLoaderBase

@@ -57,9 +61,6 @@ class ModelLoaderRegistryBase(ABC):
 """


-TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)


 class ModelLoaderRegistry:
 """
 This class allows model loaders to register their type, base and format.
@@ -70,10 +71,10 @@ class ModelLoaderRegistry:
 @classmethod
 def register(
 cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
-) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
+) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
 """Define a decorator which registers the subclass of loader."""

-def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
+def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
 key = cls._to_registry_key(base, type, format)
 if key in cls._registry:
 raise Exception(
@@ -89,15 +90,33 @@ class ModelLoaderRegistry:
 cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
 ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
 """Get subclass of ModelLoaderBase registered to handle base and type."""
+# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
+conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)

-key1 = cls._to_registry_key(config.base, config.type, config.format)  # for a specific base type
-key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format)  # with wildcard Any
+key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format)  # for a specific base type
+key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format)  # with wildcard Any
 implementation = cls._registry.get(key1) or cls._registry.get(key2)
 if not implementation:
 raise NotImplementedError(
 f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
 )
-return implementation, config, submodel_type
+return implementation, conf2, submodel_type

+@classmethod
+def _handle_subtype_overrides(
+cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
+) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
+if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
+model_path = Path(config.vae)
+config_class = (
+VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
+)
+hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
+new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
+submodel_type = None
+else:
+new_conf = config
+return new_conf, submodel_type

 @staticmethod
 def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
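Usage sketch (illustrative only, not code from either branch): the hunks above make loaders register themselves against a (base, type, format) key, and make lookup swap in a synthesized VAE config when a main model carries a "vae" override. The loader class below and the name of the lookup classmethod (get_implementation) are assumptions, since the hunk starts mid-signature; main_config stands for an AnyModelConfig obtained elsewhere.

# Hypothetical registration; a BaseModelType.Any entry acts as a wildcard
# when no base-specific entry exists.
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
class ExampleVaeLoader(ModelLoaderBase):  # hypothetical loader subclass
    def _load_model(self, config, submodel_type=None):
        raise NotImplementedError

# Resolution: when submodel_type is SubModelType.Vae and main_config.vae is set,
# the registry returns a synthesized VAE config (conf2) in place of main_config.
impl, resolved_config, submodel = ModelLoaderRegistry.get_implementation(main_config, SubModelType.Vae)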
@@ -3,8 +3,8 @@

 from pathlib import Path

+import safetensors
 import torch
-from safetensors.torch import load_file as safetensors_load_file

 from invokeai.backend.model_manager import (
 AnyModelConfig,
@@ -12,7 +12,6 @@ from invokeai.backend.model_manager import (
 ModelFormat,
 ModelType,
 )
-from invokeai.backend.model_manager.config import CheckpointConfigBase
 from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers

 from .. import ModelLoaderRegistry
@@ -21,15 +20,15 @@ from .generic_diffusers import GenericDiffusersLoader

 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
-class ControlNetLoader(GenericDiffusersLoader):
+class ControlnetLoader(GenericDiffusersLoader):
 """Class to load ControlNet models."""

 def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
-if not isinstance(config, CheckpointConfigBase):
+if config.format != ModelFormat.Checkpoint:
 return False
 elif (
 dest_path.exists()
-and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
+and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
 and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
 ):
 return False
@@ -38,13 +37,13 @@ class ControlNetLoader(GenericDiffusersLoader):

 def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
 if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
-raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
+raise Exception(f"Vae conversion not supported for model type: {config.base}")
 else:
-assert isinstance(config, CheckpointConfigBase)
-config_file = config.config_path
+assert hasattr(config, "config")
+config_file = config.config

 if model_path.suffix == ".safetensors":
-checkpoint = safetensors_load_file(model_path, device="cpu")
+checkpoint = safetensors.torch.load_file(model_path, device="cpu")
 else:
 checkpoint = torch.load(model_path, map_location="cpu")

@@ -3,10 +3,9 @@

 import sys
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Dict, Optional

-from diffusers.configuration_utils import ConfigMixin
-from diffusers.models.modeling_utils import ModelMixin
+from diffusers import ConfigMixin, ModelMixin

 from invokeai.backend.model_manager import (
 AnyModel,
@@ -42,7 +41,6 @@ class GenericDiffusersLoader(ModelLoader):
 # TO DO: Add exception handling
 def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
 """Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
-result = None
 if submodel_type:
 try:
 config = self._load_diffusers_config(model_path, config_name="model_index.json")
@@ -66,7 +64,6 @@ class GenericDiffusersLoader(ModelLoader):
 raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
 except KeyError as e:
 raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
-assert result is not None
 return result

 # TO DO: Add exception handling
@@ -78,7 +75,7 @@ class GenericDiffusersLoader(ModelLoader):
 result: ModelMixin = getattr(res_type, class_name)
 return result

-def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
+def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
 return ConfigLoader.load_config(model_path, config_name=config_name)


@@ -86,8 +83,8 @@ class ConfigLoader(ConfigMixin):
 """Subclass of ConfigMixin for loading diffusers configuration files."""

 @classmethod
-def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]:  # pyright: ignore [reportIncompatibleMethodOverride]
+def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
 """Load a diffusrs ConfigMixin configuration."""
 cls.config_name = kwargs.pop("config_name")
-# TODO(psyche): the types on this diffusers method are not correct
+# Diffusers doesn't provide typing info
 return super().load_config(*args, **kwargs)  # type: ignore
@@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
 if submodel_type is not None:
 raise ValueError("There are no submodels in an IP-Adapter model.")
 model = build_ip_adapter(
-ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
+ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
 device=torch.device("cpu"),
 dtype=self._torch_dtype,
 )
@@ -22,8 +22,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
 from .. import ModelLoader, ModelLoaderRegistry


-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
 class LoraLoader(ModelLoader):
 """Class to load LoRA models."""

@@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
 from .generic_diffusers import GenericDiffusersLoader


-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
 class OnnyxDiffusersModel(GenericDiffusersLoader):
 """Class to load onnx models."""
@@ -4,8 +4,7 @@
 from pathlib import Path
 from typing import Optional

-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
+from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline

 from invokeai.backend.model_manager import (
 AnyModel,
@@ -17,7 +16,7 @@ from invokeai.backend.model_manager import (
 ModelVariantType,
 SubModelType,
 )
-from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
+from invokeai.backend.model_manager.config import MainCheckpointConfig
 from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

 from .. import ModelLoaderRegistry
@@ -55,11 +54,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
 return result

 def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
-if not isinstance(config, CheckpointConfigBase):
+if config.format != ModelFormat.Checkpoint:
 return False
 elif (
 dest_path.exists()
-and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
+and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
 and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
 ):
 return False
@@ -74,7 +73,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
 StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
 )

-config_file = config.config_path
+config_file = config.config

 self._logger.info(f"Converting {model_path} to diffusers format")
 convert_ckpt_to_diffusers(
@@ -3,9 +3,9 @@

 from pathlib import Path

+import safetensors
 import torch
 from omegaconf import DictConfig, OmegaConf
-from safetensors.torch import load_file as safetensors_load_file

 from invokeai.backend.model_manager import (
 AnyModelConfig,
@@ -13,25 +13,24 @@ from invokeai.backend.model_manager import (
 ModelFormat,
 ModelType,
 )
-from invokeai.backend.model_manager.config import CheckpointConfigBase
 from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers

 from .. import ModelLoaderRegistry
 from .generic_diffusers import GenericDiffusersLoader


-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
 class VaeLoader(GenericDiffusersLoader):
 """Class to load VAE models."""

 def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
-if not isinstance(config, CheckpointConfigBase):
+if config.format != ModelFormat.Checkpoint:
 return False
 elif (
 dest_path.exists()
-and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
+and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
 and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
 ):
 return False
@@ -39,15 +38,16 @@ class VaeLoader(GenericDiffusersLoader):
 return True

 def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
-# TODO(MM2): check whether sdxl VAE models convert.
+# TO DO: check whether sdxl VAE models convert.
 if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
-raise Exception(f"VAE conversion not supported for model type: {config.base}")
+raise Exception(f"Vae conversion not supported for model type: {config.base}")
 else:
-assert isinstance(config, CheckpointConfigBase)
-config_file = config.config_path
+config_file = (
+"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
+)

 if model_path.suffix == ".safetensors":
-checkpoint = safetensors_load_file(model_path, device="cpu")
+checkpoint = safetensors.torch.load_file(model_path, device="cpu")
 else:
 checkpoint = torch.load(model_path, map_location="cpu")

@@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
 if "state_dict" in checkpoint:
 checkpoint = checkpoint["state_dict"]

-ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
+ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
 assert isinstance(ckpt_config, DictConfig)

 vae_model = convert_ldm_vae_to_diffusers(
@@ -16,7 +16,6 @@ from diffusers import AutoPipelineForText2Image
 from diffusers.utils import logging as dlogging

 from invokeai.app.services.model_install import ModelInstallServiceBase
-from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
 from invokeai.backend.util.devices import choose_torch_device, torch_dtype

 from . import (
@@ -118,6 +117,7 @@ class ModelMerger(object):
 config = self._installer.app_config
 store = self._installer.record_store
 base_models: Set[BaseModelType] = set()
+vae = None
 variant = None if self._installer.app_config.full_precision else "fp16"

 assert (
@@ -134,6 +134,10 @@ class ModelMerger(object):
 "normal"
 ), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"

+# pick up the first model's vae
+if key == model_keys[0]:
+vae = info.vae
+
 # tally base models used
 base_models.add(info.base)
 model_paths.extend([config.models_path / info.path])
@@ -159,10 +163,12 @@ class ModelMerger(object):

 # update model's config
 model_config = self._installer.record_store.get_model(key)
-model_config.name = merged_model_name
-model_config.description = f"Merge of models {', '.join(model_names)}"
-self._installer.record_store.update_model(
-key, ModelRecordChanges(name=model_config.name, description=model_config.description)
+model_config.update(
+{
+"name": merged_model_name,
+"description": f"Merge of models {', '.join(model_names)}",
+"vae": vae,
+}
 )
+self._installer.record_store.update_model(key, model_config)
 return model_config
@@ -25,7 +25,9 @@ from .metadata_base import (
 AnyModelRepoMetadataValidator,
 BaseMetadata,
 CivitaiMetadata,
+CommercialUsage,
 HuggingFaceMetadata,
+LicenseRestrictions,
 ModelMetadataWithFiles,
 RemoteModelFile,
 UnknownMetadataException,
@@ -36,8 +38,10 @@ __all__ = [
 "AnyModelRepoMetadataValidator",
 "CivitaiMetadata",
 "CivitaiMetadataFetch",
+"CommercialUsage",
 "HuggingFaceMetadata",
 "HuggingFaceMetadataFetch",
+"LicenseRestrictions",
 "ModelMetadataFetchBase",
 "BaseMetadata",
 "ModelMetadataWithFiles",
@@ -23,21 +23,22 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
 print(metadata.trained_words)
 """

-import json
 import re
+from datetime import datetime
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import requests
-from pydantic import TypeAdapter, ValidationError
 from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session

-from invokeai.backend.model_manager.config import ModelRepoVariant
+from invokeai.backend.model_manager import ModelRepoVariant

 from ..metadata_base import (
 AnyModelRepoMetadata,
 CivitaiMetadata,
+CommercialUsage,
+LicenseRestrictions,
 RemoteModelFile,
 UnknownMetadataException,
 )
@@ -51,13 +52,10 @@ CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
 CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"


-StringSetAdapter = TypeAdapter(set[str])


 class CivitaiMetadataFetch(ModelMetadataFetchBase):
 """Fetch model metadata from Civitai."""

-def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
+def __init__(self, session: Optional[Session] = None):
 """
 Initialize the fetcher with an optional requests.sessions.Session object.

@@ -65,7 +63,6 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
 this module without an internet connection.
 """
 self._requests = session or requests.Session()
-self._api_key = api_key

 def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
 """
@@ -105,21 +102,22 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
 May raise an `UnknownMetadataException`.
 """
 model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
-model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
-return self._from_api_response(model_json)
+model_json = self._requests.get(model_url).json()
+return self._from_model_json(model_json)

-def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
+def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
 try:
-version_id = version_id or api_response["modelVersions"][0]["id"]
+version_id = version_id or model_json["modelVersions"][0]["id"]
 except TypeError as excp:
 raise UnknownMetadataException from excp

 # loop till we find the section containing the version requested
-version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
+version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
 if not version_sections:
 raise UnknownMetadataException(f"Version {version_id} not found in model metadata")

 version_json = version_sections[0]
+safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]

 # Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
 primary = [x for x in version_json["files"] if x.get("primary")]
@@ -136,23 +134,36 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
 url = url + f"?type={primary_file['type']}{metadata_string}"
 model_files = [
 RemoteModelFile(
-url=self._get_url_with_api_key(url),
+url=url,
 path=Path(primary_file["name"]),
 size=int(primary_file["sizeKB"] * 1024),
 sha256=primary_file["hashes"]["SHA256"],
 )
 ]

-try:
-trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
-except ValidationError:
-trigger_phrases: set[str] = set()
-
 return CivitaiMetadata(
+id=model_json["id"],
 name=version_json["name"],
+version_id=version_json["id"],
+version_name=version_json["name"],
+created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
+updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
+published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
+base_model_trained_on=version_json["baseModel"],  # note - need a dictionary to turn into a BaseModelType
 files=model_files,
-trigger_phrases=trigger_phrases,
-api_response=json.dumps(version_json),
+download_url=version_json["downloadUrl"],
+thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
+author=model_json["creator"]["username"],
+description=model_json["description"],
+version_description=version_json["description"] or "",
+tags=model_json["tags"],
+trained_words=version_json["trainedWords"],
+nsfw=model_json["nsfw"],
+restrictions=LicenseRestrictions(
+AllowNoCredit=model_json["allowNoCredit"],
+AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
+AllowDerivatives=model_json["allowDerivatives"],
+AllowDifferentLicense=model_json["allowDifferentLicense"],
+),
 )

 def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
@@ -163,14 +174,14 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
 """
 if model_id is None:
 version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
-version = self._requests.get(self._get_url_with_api_key(version_url)).json()
+version = self._requests.get(version_url).json()
 if error := version.get("error"):
 raise UnknownMetadataException(error)
 model_id = version["modelId"]

 model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
-model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
-return self._from_api_response(model_json, version_id)
+model_json = self._requests.get(model_url).json()
+return self._from_model_json(model_json, version_id)

 @classmethod
 def from_json(cls, json: str) -> CivitaiMetadata:
@@ -178,11 +189,6 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
 metadata = CivitaiMetadata.model_validate_json(json)
 return metadata

-def _get_url_with_api_key(self, url: str) -> str:
-if not self._api_key:
-return url
-
-if "?" in url:
-return f"{url}&token={self._api_key}"
-
-return f"{url}?token={self._api_key}"
+def _fix_timezone(date: str) -> str:
+return re.sub(r"Z$", "+00:00", date)
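Usage sketch (illustrative only, not code from either branch): end-to-end use of the fetcher as changed above, following the module docstring quoted in the hunk header. The import path is assumed from the __init__ changes earlier in this compare, and network access to civitai.com is required.

from invokeai.backend.model_manager.metadata import CivitaiMetadataFetch  # assumed export

fetcher = CivitaiMetadataFetch()  # note: no api_key parameter in this version
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)          # trigger words for the selected version
print(metadata.download_url)           # primary file download URL
print(metadata.allow_commercial_use)   # derived from the LicenseRestrictions block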
@@ -13,7 +13,6 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
 print(metadata.tags)
 """

-import json
 import re
 from pathlib import Path
 from typing import Optional
@@ -24,7 +23,7 @@ from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFo
 from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session

-from invokeai.backend.model_manager.config import ModelRepoVariant
+from invokeai.backend.model_manager import ModelRepoVariant

 from ..metadata_base import (
 AnyModelRepoMetadata,
@@ -61,7 +60,6 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
 # Little loop which tries fetching a revision corresponding to the selected variant.
 # If not available, then set variant to None and get the default.
 # If this too fails, raise exception.
-
 model_info = None
 while not model_info:
 try:
@@ -74,24 +72,23 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
 else:
 variant = None

-files: list[RemoteModelFile] = []
-
 _, name = id.split("/")

-for s in model_info.siblings or []:
-assert s.rfilename is not None
-assert s.size is not None
-files.append(
-RemoteModelFile(
-url=hf_hub_url(id, s.rfilename, revision=variant),
-path=Path(name, s.rfilename),
-size=s.size,
-sha256=s.lfs.get("sha256") if s.lfs else None,
-)
-)
-
 return HuggingFaceMetadata(
-id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
+id=model_info.id,
+author=model_info.author,
+name=name,
+last_modified=model_info.last_modified,
+tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
+tags=model_info.tags,
+files=[
+RemoteModelFile(
+url=hf_hub_url(id, x.rfilename, revision=variant),
+path=Path(name, x.rfilename),
+size=x.size,
+sha256=x.lfs.get("sha256") if x.lfs else None,
+)
+for x in model_info.siblings
+],
 )

 def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
@@ -14,8 +14,10 @@ versions of these fields are intended to be kept in sync with the
 remote repo.
 """

+from datetime import datetime
+from enum import Enum
 from pathlib import Path
-from typing import List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union

 from huggingface_hub import configure_http_backend, hf_hub_url
 from pydantic import BaseModel, Field, TypeAdapter
@@ -23,6 +25,7 @@ from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session
 from typing_extensions import Annotated

+from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
 from invokeai.backend.model_manager import ModelRepoVariant

 from ..util import select_hf_files
@@ -32,6 +35,31 @@ class UnknownMetadataException(Exception):
 """Raised when no metadata is available for a model."""


+class CommercialUsage(str, Enum):
+"""Type of commercial usage allowed."""
+
+No = "None"
+Image = "Image"
+Rent = "Rent"
+RentCivit = "RentCivit"
+Sell = "Sell"
+
+
+class LicenseRestrictions(BaseModel):
+"""Broad categories of licensing restrictions."""
+
+AllowNoCredit: bool = Field(
+description="if true, model can be redistributed without crediting author", default=False
+)
+AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
+AllowDifferentLicense: bool = Field(
+description="if true, derivatives of this model be redistributed under a different license", default=False
+)
+AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
+description="Type of commercial use allowed if no commercial use is allowed.", default=None
+)
+
+
 class RemoteModelFile(BaseModel):
 """Information about a downloadable file that forms part of a model."""

@@ -41,10 +69,24 @@ class RemoteModelFile(BaseModel):
 sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)


+class ModelDefaultSettings(BaseModel):
+vae: str | None
+vae_precision: str | None
+scheduler: SCHEDULER_NAME_VALUES | None
+steps: int | None
+cfg_scale: float | None
+cfg_rescale_multiplier: float | None
+
+
 class ModelMetadataBase(BaseModel):
 """Base class for model metadata information."""

 name: str = Field(description="model's name")
+author: str = Field(description="model's author")
+tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
+default_settings: Optional[ModelDefaultSettings] = Field(
+description="default settings for this model", default=None
+)


 class BaseMetadata(ModelMetadataBase):
@@ -82,16 +124,60 @@ class CivitaiMetadata(ModelMetadataWithFiles):
 """Extended metadata fields provided by Civitai."""

 type: Literal["civitai"] = "civitai"
-trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
-api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
+id: int = Field(description="Civitai version identifier")
+version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
+version_id: int = Field(description="Civitai model version identifier")
+created: datetime = Field(description="date the model was created")
+updated: datetime = Field(description="date the model was last modified")
+published: datetime = Field(description="date the model was published to Civitai")
+description: str = Field(description="text description of model; may contain HTML")
+version_description: str = Field(
+description="text description of the model's reversion; usually change history; may contain HTML"
+)
+nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
+restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
+trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
+download_url: AnyHttpUrl = Field(description="download URL for this model")
+base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
+thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
+weight_minmax: Tuple[float, float] = Field(
+description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
+)  # note: For future use
+
+@property
+def credit_required(self) -> bool:
+"""Return True if you must give credit for derivatives of this model and images generated from it."""
+return not self.restrictions.AllowNoCredit
+
+@property
+def allow_commercial_use(self) -> bool:
+"""Return True if commercial use is allowed."""
+if self.restrictions.AllowCommercialUse is None:
+return False
+else:
+# accommodate schema change
+acu = self.restrictions.AllowCommercialUse
+commercial_usage = acu if isinstance(acu, set) else {acu}
+return CommercialUsage.No not in commercial_usage
+
+@property
+def allow_derivatives(self) -> bool:
+"""Return True if derivatives of this model can be redistributed."""
+return self.restrictions.AllowDerivatives
+
+@property
+def allow_different_license(self) -> bool:
+"""Return true if derivatives of this model can use a different license."""
+return self.restrictions.AllowDifferentLicense


 class HuggingFaceMetadata(ModelMetadataWithFiles):
 """Extended metadata fields provided by HuggingFace."""

 type: Literal["huggingface"] = "huggingface"
-id: str = Field(description="The HF model id")
-api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
+id: str = Field(description="huggingface model id")
+tag_dict: Dict[str, Any]
+last_modified: datetime = Field(description="date of last commit to repo")

 def download_urls(
 self,
@@ -120,7 +206,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
 # the next step reads model_index.json to determine which subdirectories belong
 # to the model
 if Path(f"{prefix}model_index.json") in paths:
-url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
+url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
 resp = session.get(url)
 resp.raise_for_status()
 submodels = resp.json()
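Usage sketch (illustrative only, not code from either branch): how the license fields added above combine. The field values are invented for the example; the classes are those defined in the hunk.

r = LicenseRestrictions(
    AllowNoCredit=False,
    AllowDerivatives=True,
    AllowCommercialUse={CommercialUsage.Image, CommercialUsage.Rent},
)
# CivitaiMetadata.allow_commercial_use normalizes a single value into a set first:
acu = r.AllowCommercialUse
commercial_usage = acu if isinstance(acu, set) else {acu}
print(CommercialUsage.No not in commercial_usage)  # True: some commercial use is permitted
print(not r.AllowNoCredit)                         # True: credit is required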
invokeai/backend/model_manager/metadata/metadata_store.py (new file, 221 lines)
@@ -0,0 +1,221 @@
+# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
+"""
+SQL Storage for Model Metadata
+"""
+
+import sqlite3
+from typing import List, Optional, Set, Tuple
+
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
+
+from .fetch import ModelMetadataFetchBase
+from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
+
+
+class ModelMetadataStore:
+"""Store, search and fetch model metadata retrieved from remote repositories."""
+
+def __init__(self, db: SqliteDatabase):
+"""
+Initialize a new object from preexisting sqlite3 connection and threading lock objects.
+
+:param conn: sqlite3 connection object
+:param lock: threading Lock object
+"""
+super().__init__()
+self._db = db
+self._cursor = self._db.conn.cursor()
+
+def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
+"""
+Add a block of repo metadata to a model record.
+
+The model record config must already exist in the database with the
+same key. Otherwise a FOREIGN KEY constraint exception will be raised.
+
+:param model_key: Existing model key in the `model_config` table
+:param metadata: ModelRepoMetadata object to store
+"""
+json_serialized = metadata.model_dump_json()
+with self._db.lock:
+try:
+self._cursor.execute(
+"""--sql
+INSERT INTO model_metadata(
+id,
+metadata
+)
+VALUES (?,?);
+""",
+(
+model_key,
+json_serialized,
+),
+)
+self._update_tags(model_key, metadata.tags)
+self._db.conn.commit()
+except sqlite3.IntegrityError as excp:  # FOREIGN KEY error: the key was not in model_config table
+self._db.conn.rollback()
+raise UnknownMetadataException from excp
+except sqlite3.Error as excp:
+self._db.conn.rollback()
+raise excp
+
+def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
+"""Retrieve the ModelRepoMetadata corresponding to model key."""
+with self._db.lock:
+self._cursor.execute(
+"""--sql
+SELECT metadata FROM model_metadata
+WHERE id=?;
+""",
+(model_key,),
+)
+rows = self._cursor.fetchone()
+if not rows:
+raise UnknownMetadataException("model metadata not found")
+return ModelMetadataFetchBase.from_json(rows[0])
+
+def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:  # key, metadata
+"""Dump out all the metadata."""
+with self._db.lock:
+self._cursor.execute(
+"""--sql
+SELECT id,metadata FROM model_metadata;
+""",
+(),
+)
+rows = self._cursor.fetchall()
+return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
+
+def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
+"""
+Update metadata corresponding to the model with the indicated key.
+
+:param model_key: Existing model key in the `model_config` table
+:param metadata: ModelRepoMetadata object to update
+"""
+json_serialized = metadata.model_dump_json()  # turn it into a json string.
+with self._db.lock:
+try:
+self._cursor.execute(
+"""--sql
+UPDATE model_metadata
+SET
+metadata=?
+WHERE id=?;
+""",
+(json_serialized, model_key),
+)
+if self._cursor.rowcount == 0:
+raise UnknownMetadataException("model metadata not found")
+self._update_tags(model_key, metadata.tags)
+self._db.conn.commit()
+except sqlite3.Error as e:
+self._db.conn.rollback()
+raise e
+
+return self.get_metadata(model_key)
+
+def list_tags(self) -> Set[str]:
+"""Return all tags in the tags table."""
+self._cursor.execute(
+"""--sql
+select tag_text from tags;
+"""
+)
+return {x[0] for x in self._cursor.fetchall()}
+
+def search_by_tag(self, tags: Set[str]) -> Set[str]:
+"""Return the keys of models containing all of the listed tags."""
+with self._db.lock:
+try:
+matches: Optional[Set[str]] = None
+for tag in tags:
+self._cursor.execute(
+"""--sql
+SELECT a.model_id FROM model_tags AS a,
+tags AS b
+WHERE a.tag_id=b.tag_id
+AND b.tag_text=?;
+""",
+(tag,),
+)
+model_keys = {x[0] for x in self._cursor.fetchall()}
+if matches is None:
+matches = model_keys
+matches = matches.intersection(model_keys)
+except sqlite3.Error as e:
+raise e
+return matches if matches else set()
+
+def search_by_author(self, author: str) -> Set[str]:
+"""Return the keys of models authored by the indicated author."""
+self._cursor.execute(
+"""--sql
+SELECT id FROM model_metadata
+WHERE author=?;
+""",
+(author,),
+)
+return {x[0] for x in self._cursor.fetchall()}
+
+def search_by_name(self, name: str) -> Set[str]:
+"""
+Return the keys of models with the indicated name.
+
+Note that this is the name of the model given to it by
+the remote source. The user may have changed the local
+name. The local name will be located in the model config
+record object.
+"""
+self._cursor.execute(
+"""--sql
+SELECT id FROM model_metadata
+WHERE name=?;
+""",
+(name,),
+)
+return {x[0] for x in self._cursor.fetchall()}
+
+def _update_tags(self, model_key: str, tags: Set[str]) -> None:
+"""Update tags for the model referenced by model_key."""
+# remove previous tags from this model
+self._cursor.execute(
+"""--sql
+DELETE FROM model_tags
+WHERE model_id=?;
+""",
+(model_key,),
+)
+
+for tag in tags:
+self._cursor.execute(
+"""--sql
+INSERT OR IGNORE INTO tags (
+tag_text
+)
+VALUES (?);
+""",
+(tag,),
+)
+self._cursor.execute(
+"""--sql
+SELECT tag_id
+FROM tags
+WHERE tag_text = ?
+LIMIT 1;
+""",
+(tag,),
+)
+tag_id = self._cursor.fetchone()[0]
+self._cursor.execute(
+"""--sql
+INSERT OR IGNORE INTO model_tags (
+model_id,
+tag_id
+)
+VALUES (?,?);
+""",
+(model_key, tag_id),
+)
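Usage sketch (illustrative only, not code from either branch): basic use of the ModelMetadataStore added above. The db, key, and metadata objects are assumed to exist already (a SqliteDatabase, an installed model's key, and an AnyModelRepoMetadata instance); the tag names are invented.

store = ModelMetadataStore(db)
store.add_metadata(key, metadata)        # raises UnknownMetadataException if key is not in model_config
print(store.get_metadata(key).name)
print(store.list_tags())                 # all tag strings stored so far
print(store.search_by_tag({"anime", "style"}))  # keys of models carrying both tags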
@@ -8,7 +8,6 @@ import torch
 from picklescan.scanner import scan_file_path

 import invokeai.backend.util.logging as logger
-from invokeai.app.util.misc import uuid_string
 from invokeai.backend.util.util import SilenceWarnings

 from .config import (
@@ -18,7 +17,6 @@ from .config import (
 ModelConfigFactory,
 ModelFormat,
 ModelRepoVariant,
-ModelSourceType,
 ModelType,
 ModelVariantType,
 SchedulerPredictionType,
@@ -97,8 +95,8 @@ class ModelProbe(object):
 "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
 "StableDiffusionXLInpaintPipeline": ModelType.Main,
 "LatentConsistencyModelPipeline": ModelType.Main,
-"AutoencoderKL": ModelType.VAE,
-"AutoencoderTiny": ModelType.VAE,
+"AutoencoderKL": ModelType.Vae,
+"AutoencoderTiny": ModelType.Vae,
 "ControlNetModel": ModelType.ControlNet,
 "CLIPVisionModelWithProjection": ModelType.CLIPVision,
 "T2IAdapter": ModelType.T2IAdapter,
@@ -110,6 +108,14 @@ class ModelProbe(object):
 ) -> None:
 cls.PROBES[format][model_type] = probe_class

+@classmethod
+def heuristic_probe(
+cls,
+model_path: Path,
+fields: Optional[Dict[str, Any]] = None,
+) -> AnyModelConfig:
+return cls.probe(model_path, fields)
+
 @classmethod
 def probe(
 cls,
@@ -131,21 +137,19 @@ class ModelProbe(object):
 format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
 model_info = None
 model_type = None
-if format_type is ModelFormat.Diffusers:
+if format_type == "diffusers":
 model_type = cls.get_model_type_from_folder(model_path)
 else:
 model_type = cls.get_model_type_from_checkpoint(model_path)
-format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
+format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type

 probe_class = cls.PROBES[format_type].get(model_type)
 if not probe_class:
 raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")

+hash = ModelHash().hash(model_path)
 probe = probe_class(model_path)

-fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
-fields["source"] = fields.get("source") or model_path.as_posix()
-fields["key"] = fields.get("key", uuid_string())
 fields["path"] = model_path.as_posix()
 fields["type"] = fields.get("type") or model_type
 fields["base"] = fields.get("base") or probe.get_base_type()
@@ -157,17 +161,15 @@ class ModelProbe(object):
 fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
 )
 fields["format"] = fields.get("format") or probe.get_format()
-fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
+fields["original_hash"] = fields.get("original_hash") or hash
+fields["current_hash"] = fields.get("current_hash") or hash

-if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
+if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
 fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()

 # additional fields needed for main and controlnet models
# additional fields needed for main and controlnet models
|
# additional fields needed for main and controlnet models
|
||||||
if (
|
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
||||||
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
|
fields["config"] = cls._get_checkpoint_config_path(
|
||||||
and fields["format"] is ModelFormat.Checkpoint
|
|
||||||
):
|
|
||||||
fields["config_path"] = cls._get_checkpoint_config_path(
|
|
||||||
model_path,
|
model_path,
|
||||||
model_type=fields["type"],
|
model_type=fields["type"],
|
||||||
base_type=fields["base"],
|
base_type=fields["base"],
|
||||||
@@ -177,7 +179,7 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
# additional fields needed for main non-checkpoint models
|
# additional fields needed for main non-checkpoint models
|
||||||
elif fields["type"] == ModelType.Main and fields["format"] in [
|
elif fields["type"] == ModelType.Main and fields["format"] in [
|
||||||
ModelFormat.ONNX,
|
ModelFormat.Onnx,
|
||||||
ModelFormat.Olive,
|
ModelFormat.Olive,
|
||||||
ModelFormat.Diffusers,
|
ModelFormat.Diffusers,
|
||||||
]:
|
]:
|
||||||
@@ -211,11 +213,11 @@ class ModelProbe(object):
|
|||||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||||
return ModelType.Main
|
return ModelType.Main
|
||||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||||
return ModelType.VAE
|
return ModelType.Vae
|
||||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||||
return ModelType.LoRA
|
return ModelType.Lora
|
||||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||||
return ModelType.LoRA
|
return ModelType.Lora
|
||||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||||
return ModelType.ControlNet
|
return ModelType.ControlNet
|
||||||
elif key in {"emb_params", "string_to_param"}:
|
elif key in {"emb_params", "string_to_param"}:
|
||||||
@@ -237,7 +239,7 @@ class ModelProbe(object):
|
|||||||
if (folder_path / f"learned_embeds.{suffix}").exists():
|
if (folder_path / f"learned_embeds.{suffix}").exists():
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
||||||
return ModelType.LoRA
|
return ModelType.Lora
|
||||||
if (folder_path / "unet/model.onnx").exists():
|
if (folder_path / "unet/model.onnx").exists():
|
||||||
return ModelType.ONNX
|
return ModelType.ONNX
|
||||||
if (folder_path / "image_encoder.txt").exists():
|
if (folder_path / "image_encoder.txt").exists():
|
||||||
@@ -283,21 +285,13 @@ class ModelProbe(object):
|
|||||||
if possible_conf.exists():
|
if possible_conf.exists():
|
||||||
return possible_conf.absolute()
|
return possible_conf.absolute()
|
||||||
|
|
||||||
if model_type is ModelType.Main:
|
if model_type == ModelType.Main:
|
||||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||||
config_file = config_file[prediction_type]
|
config_file = config_file[prediction_type]
|
||||||
elif model_type is ModelType.ControlNet:
|
elif model_type == ModelType.ControlNet:
|
||||||
config_file = (
|
config_file = (
|
||||||
"../controlnet/cldm_v15.yaml"
|
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
|
||||||
if base_type is BaseModelType.StableDiffusion1
|
|
||||||
else "../controlnet/cldm_v21.yaml"
|
|
||||||
)
|
|
||||||
elif model_type is ModelType.VAE:
|
|
||||||
config_file = (
|
|
||||||
"../stable-diffusion/v1-inference.yaml"
|
|
||||||
if base_type is BaseModelType.StableDiffusion1
|
|
||||||
else "../stable-diffusion/v2-inference.yaml"
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise InvalidModelConfigException(
|
raise InvalidModelConfigException(
|
||||||
@@ -503,12 +497,12 @@ class FolderProbeBase(ProbeBase):
|
|||||||
if ".fp16" in x.suffixes:
|
if ".fp16" in x.suffixes:
|
||||||
return ModelRepoVariant.FP16
|
return ModelRepoVariant.FP16
|
||||||
if "openvino_model" in x.name:
|
if "openvino_model" in x.name:
|
||||||
return ModelRepoVariant.OpenVINO
|
return ModelRepoVariant.OPENVINO
|
||||||
if "flax_model" in x.name:
|
if "flax_model" in x.name:
|
||||||
return ModelRepoVariant.Flax
|
return ModelRepoVariant.FLAX
|
||||||
if x.suffix == ".onnx":
|
if x.suffix == ".onnx":
|
||||||
return ModelRepoVariant.ONNX
|
return ModelRepoVariant.ONNX
|
||||||
return ModelRepoVariant.Default
|
return ModelRepoVariant.DEFAULT
|
||||||
|
|
||||||
|
|
||||||
class PipelineFolderProbe(FolderProbeBase):
|
class PipelineFolderProbe(FolderProbeBase):
|
||||||
@@ -714,8 +708,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):
|
|||||||
|
|
||||||
############## register probe classes ######
|
############## register probe classes ######
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||||
@@ -723,8 +717,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
|
|||||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||||
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||||
|
|||||||
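A rough usage sketch of the probe entry points touched in the diff above. On the incoming side, `heuristic_probe` simply forwards to `probe`, which inspects the file or folder layout to fill in type, base, format, hashes and (for checkpoints) the legacy config path. The import path and model locations are assumptions for illustration:

```python
from pathlib import Path

from invokeai.backend.model_manager.probe import ModelProbe  # assumed import path

# Hypothetical model locations.
diffusers_dir = Path("/models/sd-1/main/dreamshaper-8")            # folder -> Diffusers format
checkpoint_file = Path("/models/sd-1/main/dreamshaper-8.safetensors")  # file -> Checkpoint format

config = ModelProbe.heuristic_probe(diffusers_dir)
print(config)

# Caller-supplied fields take precedence over probed values.
config = ModelProbe.probe(checkpoint_file, {"name": "my-checkpoint"})
print(config)
```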
@@ -13,7 +13,6 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx')
 """

 import re
-from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List, Optional, Set

@@ -35,7 +34,7 @@ def filter_files(
     The file list can be obtained from the `files` field of HuggingFaceMetadata,
     as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
     """
-    variant = variant or ModelRepoVariant.Default
+    variant = variant or ModelRepoVariant.DEFAULT
     paths: List[Path] = []
     root = files[0].parts[0]

@@ -74,81 +73,64 @@ def filter_files(
     return sorted(_filter_by_variant(paths, variant))


-@dataclass
-class SubfolderCandidate:
-    path: Path
-    score: int
-
-
 def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
     """Select the proper variant files from a list of HuggingFace repo_id paths."""
-    result: set[Path] = set()
-    subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
+    result = set()
+    basenames: Dict[Path, Path] = {}
     for path in files:
         if path.suffix in [".onnx", ".pb", ".onnx_data"]:
             if variant == ModelRepoVariant.ONNX:
                 result.add(path)

         elif "openvino_model" in path.name:
-            if variant == ModelRepoVariant.OpenVINO:
+            if variant == ModelRepoVariant.OPENVINO:
                 result.add(path)

         elif "flax_model" in path.name:
-            if variant == ModelRepoVariant.Flax:
+            if variant == ModelRepoVariant.FLAX:
                 result.add(path)

         elif path.suffix in [".json", ".txt"]:
             result.add(path)

-        elif variant in [
+        elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
             ModelRepoVariant.FP16,
             ModelRepoVariant.FP32,
-            ModelRepoVariant.Default,
-        ] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
-            # For weights files, we want to select the best one for each subfolder. For example, we may have multiple
-            # text encoders:
-            #
-            # - text_encoder/model.fp16.safetensors
-            # - text_encoder/model.safetensors
-            # - text_encoder/pytorch_model.bin
-            # - text_encoder/pytorch_model.fp16.bin
-            #
-            # We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
-            # variant and format and select the best one.
-
+            ModelRepoVariant.DEFAULT,
+        ]:
             parent = path.parent
-            score = 0
-
-            if path.suffix == ".safetensors":
-                score += 1
-
-            candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
-
-            # Some special handling is needed here if there is not an exact match and if we cannot infer the variant
-            # from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
-            if candidate_variant_label == f".{variant}" or (
-                not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
-            ):
-                score += 1
-
-            if parent not in subfolder_weights:
-                subfolder_weights[parent] = []
-
-            subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score))
-
+            suffixes = path.suffixes
+            if len(suffixes) == 2:
+                variant_label, suffix = suffixes
+                basename = parent / Path(path.stem).stem
+            else:
+                variant_label = ""
+                suffix = suffixes[0]
+                basename = parent / path.stem
+
+            if previous := basenames.get(basename):
+                if (
+                    previous.suffix != ".safetensors" and suffix == ".safetensors"
+                ):  # replace non-safetensors with safetensors when available
+                    basenames[basename] = path
+                if variant_label == f".{variant}":
+                    basenames[basename] = path
+                elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
+                    basenames[basename] = path
+            else:
+                basenames[basename] = path
+
         else:
             continue

-    for candidate_list in subfolder_weights.values():
-        highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
-        if highest_score_candidate:
-            result.add(highest_score_candidate.path)
+    for v in basenames.values():
+        result.add(v)

     # If one of the architecture-related variants was specified and no files matched other than
     # config and text files then we return an empty list
     if (
         variant
-        and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OpenVINO, ModelRepoVariant.Flax]
+        and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
         and not any(variant.value in x.name for x in result)
     ):
         return set()
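Both sides of `_filter_by_variant` above implement the same policy: keep one weights file per subfolder, preferring safetensors and an exact variant match, while config and text files are always kept. A small sketch of the intended behavior on a made-up file list; the import paths are assumptions, and the exact tie-breaking differs between the two sides of the diff:

```python
from pathlib import Path

from invokeai.backend.model_manager.config import ModelRepoVariant            # assumed import path
from invokeai.backend.model_manager.util.select_hf_files import filter_files  # assumed import path

files = [
    Path("stable-diffusion-v1-5/model_index.json"),
    Path("stable-diffusion-v1-5/text_encoder/model.fp16.safetensors"),
    Path("stable-diffusion-v1-5/text_encoder/model.safetensors"),
    Path("stable-diffusion-v1-5/text_encoder/pytorch_model.bin"),
]

# With variant=FP16, the fp16 safetensors file is expected to win for the text_encoder
# subfolder; the json config is kept regardless of variant.
selected = filter_files(files, variant=ModelRepoVariant.FP16)
for p in selected:
    print(p)
```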
@@ -23,9 +23,12 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    IPAdapterConditioningInfo,
+    TextConditioningData,
+)
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher

 from ..util import auto_detect_slice_size, normalize_device

@@ -170,10 +173,11 @@ class ControlNetData:

 @dataclass
 class IPAdapterData:
-    ip_adapter_model: IPAdapter = Field(default=None)
-    # TODO: change to polymorphic so can do different weights per step (once implemented...)
+    ip_adapter_model: IPAdapter
+    ip_adapter_conditioning: IPAdapterConditioningInfo
+
+    # Either a single weight applied to all steps, or a list of weights for each step.
     weight: Union[float, List[float]] = Field(default=1.0)
-    # weight: float = Field(default=1.0)
     begin_step_percent: float = Field(default=0.0)
     end_step_percent: float = Field(default=1.0)

@@ -314,7 +318,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self,
         latents: torch.Tensor,
         num_inference_steps: int,
-        conditioning_data: ConditioningData,
+        scheduler_step_kwargs: dict[str, Any],
+        conditioning_data: TextConditioningData,
         *,
         noise: Optional[torch.Tensor],
         timesteps: torch.Tensor,
@@ -374,6 +379,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             latents,
             timesteps,
             conditioning_data,
+            scheduler_step_kwargs=scheduler_step_kwargs,
             additional_guidance=additional_guidance,
             control_data=control_data,
             ip_adapter_data=ip_adapter_data,
@@ -393,7 +399,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self,
         latents: torch.Tensor,
         timesteps,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
+        scheduler_step_kwargs: dict[str, Any],
         *,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
@@ -410,22 +417,35 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents

-        ip_adapter_unet_patcher = None
-        extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+        extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
+        use_cross_attention_control = (
+            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        )
+        use_ip_adapter = ip_adapter_data is not None
+        use_regional_prompting = (
+            conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
+        )
+        if use_cross_attention_control and use_ip_adapter:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously."
+            )
+        if use_cross_attention_control and use_regional_prompting:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
+            )
+
+        unet_attention_patcher = None
+        self.use_ip_adapter = use_ip_adapter
+        attn_ctx = nullcontext()
+        if use_cross_attention_control:
             attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=extra_conditioning_info,
             )
-            self.use_ip_adapter = False
-        elif ip_adapter_data is not None:
-            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
-            # As it is now, the IP-Adapter will silently be skipped.
-            ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
-            attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
-            self.use_ip_adapter = True
-        else:
-            attn_ctx = nullcontext()
+        if use_ip_adapter or use_regional_prompting:
+            ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
+            unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+            attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

         with attn_ctx:
             if callback is not None:
@@ -448,22 +468,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     conditioning_data,
                     step_index=i,
                     total_step_count=len(timesteps),
+                    scheduler_step_kwargs=scheduler_step_kwargs,
                     additional_guidance=additional_guidance,
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
-                    ip_adapter_unet_patcher=ip_adapter_unet_patcher,
+                    unet_attention_patcher=unet_attention_patcher,
                 )
                 latents = step_output.prev_sample
-
-                latents = self.invokeai_diffuser.do_latent_postprocessing(
-                    postprocessing_settings=conditioning_data.postprocessing_settings,
-                    latents=latents,
-                    sigma=batched_t,
-                    step_index=i,
-                    total_step_count=len(timesteps),
-                )
-
                 predicted_original = getattr(step_output, "pred_original_sample", None)

                 if callback is not None:
@@ -485,14 +497,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self,
         t: torch.Tensor,
         latents: torch.Tensor,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
         step_index: int,
         total_step_count: int,
+        scheduler_step_kwargs: dict[str, Any],
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
+        unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -515,10 +528,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
                 if step_index >= first_adapter_step and step_index <= last_adapter_step:
                     # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                    ip_adapter_unet_patcher.set_scale(i, weight)
+                    unet_attention_patcher.set_scale(i, weight)
                 else:
                     # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                    ip_adapter_unet_patcher.set_scale(i, 0.0)
+                    unet_attention_patcher.set_scale(i, 0.0)

         # Handle ControlNet(s)
         down_block_additional_residuals = None
@@ -562,12 +575,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

             down_intrablock_additional_residuals = accum_adapter_state

+        ip_adapter_conditioning = None
+        if ip_adapter_data is not None:
+            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
+
         uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
             sample=latent_model_input,
             timestep=t,  # TODO: debug how handled batched and non batched timesteps
             step_index=step_index,
             total_step_count=total_step_count,
             conditioning_data=conditioning_data,
+            ip_adapter_conditioning=ip_adapter_conditioning,
             down_block_additional_residuals=down_block_additional_residuals,  # for ControlNet
             mid_block_additional_residual=mid_block_additional_residual,  # for ControlNet
             down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # for T2I-Adapter
@@ -587,7 +605,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         )

         # compute the previous noisy sample x_t -> x_t-1
-        step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
+        step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)

         # TODO: issue to diffusers?
         # undo internal counter increment done by scheduler.step, so timestep can be resolved as before call
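With `ConditioningData.scheduler_args` gone, the caller now passes `scheduler_step_kwargs` explicitly into the denoising loop. A minimal sketch of how such kwargs can be filtered against a scheduler's `step()` signature, mirroring the removed `add_scheduler_args_if_applicable` helper (the helper name and call site below are hypothetical):

```python
import inspect


def build_scheduler_step_kwargs(scheduler, **kwargs) -> dict:
    """Keep only the kwargs that this scheduler's step() actually accepts (illustrative helper)."""
    accepted = {}
    step_signature = inspect.signature(scheduler.step)
    for name, value in kwargs.items():
        if name in step_signature.parameters:
            accepted[name] = value
    return accepted


# Example: DDIM-style schedulers accept `eta`; others would silently reject it, so filter first.
# scheduler_step_kwargs = build_scheduler_step_kwargs(scheduler, eta=0.0, generator=generator)
```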
@@ -1,7 +1,5 @@
-import dataclasses
-import inspect
-from dataclasses import dataclass, field
-from typing import Any, List, Optional, Union
+from dataclasses import dataclass
+from typing import List, Optional, Union

 import torch

@@ -10,6 +8,11 @@ from .cross_attention_control import Arguments

 @dataclass
 class ExtraConditioningInfo:
+    """Extra conditioning information produced by Compel.
+
+    This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel).
+    """
+
     tokens_count_including_eos_bos: int
     cross_attention_control_args: Optional[Arguments] = None

@@ -20,6 +23,8 @@ class ExtraConditioningInfo:

 @dataclass
 class BasicConditioningInfo:
+    """SD 1/2 text conditioning information produced by Compel."""
+
     embeds: torch.Tensor
     extra_conditioning: Optional[ExtraConditioningInfo]

@@ -35,6 +40,8 @@ class ConditioningFieldData:

 @dataclass
 class SDXLConditioningInfo(BasicConditioningInfo):
+    """SDXL text conditioning information produced by Compel."""
+
     pooled_embeds: torch.Tensor
     add_time_ids: torch.Tensor

@@ -44,14 +51,6 @@ class SDXLConditioningInfo(BasicConditioningInfo):
         return super().to(device=device, dtype=dtype)


-@dataclass(frozen=True)
-class PostprocessingSettings:
-    threshold: float
-    warmup: float
-    h_symmetry_time_pct: Optional[float]
-    v_symmetry_time_pct: Optional[float]
-
-
 @dataclass
 class IPAdapterConditioningInfo:
     cond_image_prompt_embeds: torch.Tensor
@@ -65,41 +64,55 @@ class IPAdapterConditioningInfo:


 @dataclass
-class ConditioningData:
-    unconditioned_embeddings: BasicConditioningInfo
-    text_embeddings: BasicConditioningInfo
-    """
-    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
-    Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
-    images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
-    """
-    guidance_scale: Union[float, List[float]]
-    """ for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
-    ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
-    """
-    guidance_rescale_multiplier: float = 0
-    scheduler_args: dict[str, Any] = field(default_factory=dict)
-    """
-    Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
-    """
-    postprocessing_settings: Optional[PostprocessingSettings] = None
-
-    ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
-
-    @property
-    def dtype(self):
-        return self.text_embeddings.dtype
-
-    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
-        scheduler_args = dict(self.scheduler_args)
-        step_method = inspect.signature(scheduler.step)
-        for name, value in kwargs.items():
-            try:
-                step_method.bind_partial(**{name: value})
-            except TypeError:
-                # FIXME: don't silently discard arguments
-                pass  # debug("%s does not accept argument named %r", scheduler, name)
-            else:
-                scheduler_args[name] = value
-        return dataclasses.replace(self, scheduler_args=scheduler_args)
+class Range:
+    start: int
+    end: int
+
+
+class TextConditioningRegions:
+    def __init__(
+        self,
+        masks: torch.Tensor,
+        ranges: list[Range],
+        mask_weights: list[float],
+    ):
+        # A binary mask indicating the regions of the image that the prompt should be applied to.
+        # Shape: (1, num_prompts, height, width)
+        # Dtype: torch.bool
+        self.masks = masks
+
+        # A list of ranges indicating the start and end indices of the embeddings that corresponding mask applies to.
+        # ranges[i] contains the embedding range for the i'th prompt / mask.
+        self.ranges = ranges
+
+        self.mask_weights = mask_weights
+
+        assert self.masks.shape[1] == len(self.ranges) == len(self.mask_weights)
+
+
+class TextConditioningData:
+    def __init__(
+        self,
+        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        uncond_regions: Optional[TextConditioningRegions],
+        cond_regions: Optional[TextConditioningRegions],
+        guidance_scale: Union[float, List[float]],
+        guidance_rescale_multiplier: float = 0,
+    ):
+        self.uncond_text = uncond_text
+        self.cond_text = cond_text
+        self.uncond_regions = uncond_regions
+        self.cond_regions = cond_regions
+        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+        # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
+        # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
+        # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+        self.guidance_scale = guidance_scale
+        # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
+        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+        self.guidance_rescale_multiplier = guidance_rescale_multiplier
+
+    def is_sdxl(self):
+        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
+        return isinstance(self.cond_text, SDXLConditioningInfo)
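A minimal sketch of how the new containers fit together for two prompt regions. The Compel-produced conditioning objects are stand-ins (assumed, not constructed here), and the mask resolution and embedding lengths are illustrative:

```python
import torch

# Two prompt regions over a 64x64 latent grid: left half and right half.
h, w = 64, 64
masks = torch.zeros((1, 2, h, w), dtype=torch.bool)
masks[0, 0, :, : w // 2] = True
masks[0, 1, :, w // 2 :] = True

# Each region maps to a slice of the concatenated text embedding sequence.
regions = TextConditioningRegions(
    masks=masks,
    ranges=[Range(start=0, end=77), Range(start=77, end=154)],
    mask_weights=[1.0, 1.0],
)

# uncond_info / cond_info: BasicConditioningInfo or SDXLConditioningInfo objects
# produced by Compel elsewhere (placeholders here, not constructed in this sketch).
cond_data = TextConditioningData(
    uncond_text=uncond_info,
    cond_text=cond_info,
    uncond_regions=None,
    cond_regions=regions,
    guidance_scale=7.5,
)
```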
 invokeai/backend/stable_diffusion/diffusion/custom_attention.py (new file, 242 lines)
@@ -0,0 +1,242 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.attention_processor import Attention, AttnProcessor2_0
+from diffusers.utils import USE_PEFT_BACKEND
+
+from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
+
+
+class CustomAttnProcessor2_0(AttnProcessor2_0):
+    """A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
+
+    This implementation is based on
+    https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
+
+    Supported custom features:
+    - IP-Adapter
+    - Regional prompt attention
+    """
+
+    def __init__(
+        self,
+        ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
+        ip_adapter_scales: Optional[list[float]] = None,
+    ):
+        """Initialize a CustomAttnProcessor2_0.
+
+        Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
+        layer-specific are passed to __init__().
+
+        Args:
+            ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
+                for the i'th IP-Adapter.
+            ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for
+                the i'th IP-Adapter.
+        """
+        super().__init__()
+
+        self._ip_adapter_weights = ip_adapter_weights
+        self._ip_adapter_scales = ip_adapter_scales
+
+        assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None)
+        if self._ip_adapter_weights is not None:
+            assert len(ip_adapter_weights) == len(ip_adapter_scales)
+
+    def _is_ip_adapter_enabled(self) -> bool:
+        return self._ip_adapter_weights is not None
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        # For regional prompting:
+        regional_prompt_data: Optional[RegionalPromptData] = None,
+        percent_through: Optional[float] = None,
+        # For IP-Adapter:
+        ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
+    ) -> torch.FloatTensor:
+        """Apply attention.
+
+        Args:
+            regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
+                apply regional prompt masking.
+            ip_adapter_image_prompt_embeds: The IP-Adapter image prompt embeddings for the current batch.
+                ip_adapter_image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each
+                tensor has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
+        """
+        # If true, we are doing cross-attention, if false we are doing self-attention.
+        is_cross_attention = encoder_hidden_states is not None
+
+        # Start unmodified block from AttnProcessor2_0.
+        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+        # End unmodified block from AttnProcessor2_0.
+
+        # Handle regional prompt attention masks.
+        if regional_prompt_data is not None:
+            assert percent_through is not None
+            _, query_seq_len, _ = hidden_states.shape
+            if is_cross_attention:
+                prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
+                    query_seq_len=query_seq_len, key_seq_len=sequence_length
+                )
+                # TODO(ryand): Avoid redundant type/device conversion here.
+                prompt_region_attention_mask = prompt_region_attention_mask.to(
+                    dtype=hidden_states.dtype, device=hidden_states.device
+                )
+
+                attn_mask_weight = 1.0 * ((1 - percent_through) ** 5)
+            else:  # self-attention
+                prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
+                    query_seq_len=query_seq_len,
+                    percent_through=percent_through,
+                )
+                attn_mask_weight = 0.3 * ((1 - percent_through) ** 5)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+        query = attn.to_q(hidden_states, *args)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if regional_prompt_data is not None and percent_through < 0.3:
+            # Don't apply to uncond????
+
+            prompt_region_attention_mask = attn.prepare_attention_mask(
+                prompt_region_attention_mask, sequence_length, batch_size
+            )
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            prompt_region_attention_mask = prompt_region_attention_mask.view(
+                batch_size, attn.heads, -1, prompt_region_attention_mask.shape[-1]
+            )
+
+            scale_factor = 1 / math.sqrt(query.size(-1))
+            attn_weight = query @ key.transpose(-2, -1) * scale_factor
+            m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight
+            m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0]
+
+            prompt_region_attention_mask = attn_mask_weight * (
+                m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
+            )
+
+            if attention_mask is None:
+                attention_mask = prompt_region_attention_mask
+            else:
+                attention_mask = prompt_region_attention_mask + attention_mask
+        else:
+            pass
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+        # End unmodified block from AttnProcessor2_0.
+
+        # Apply IP-Adapter conditioning.
+        if is_cross_attention and self._is_ip_adapter_enabled():
+            if self._is_ip_adapter_enabled():
+                assert ip_adapter_image_prompt_embeds is not None
+                for ipa_embed, ipa_weights, scale in zip(
+                    ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True
+                ):
+                    # The batch dimensions should match.
+                    assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
+                    # The token_len dimensions should match.
+                    assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
+
+                    ip_hidden_states = ipa_embed
+
+                    # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
+
+                    ip_key = ipa_weights.to_k_ip(ip_hidden_states)
+                    ip_value = ipa_weights.to_v_ip(ip_hidden_states)
+
+                    # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
+
+                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                    # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
+
+                    # TODO: add support for attn.scale when we move to Torch 2.1
+                    ip_hidden_states = F.scaled_dot_product_attention(
+                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                    )
+
+                    # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
+
+                    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                    ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+                    # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
+
+                    hidden_states = hidden_states + scale * ip_hidden_states
+            else:
+                # If IP-Adapter is not enabled, then ip_adapter_image_prompt_embeds should not be passed in.
+                assert ip_adapter_image_prompt_embeds is None
+
+        # Start unmodified block from AttnProcessor2_0.
+        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
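The regional handling above biases, rather than hard-masks, the attention logits: `m_pos`/`m_neg` measure each query's headroom up to its own max logit and down to its own min logit, and the bias strength decays as `percent_through` grows. A self-contained sketch of that rescaling on toy tensors (shapes and the 0.1 progress value are illustrative, not taken from the file):

```python
import math

import torch

torch.manual_seed(0)
query = torch.randn(1, 1, 4, 8)  # (batch, heads, query_seq_len, head_dim)
key = torch.randn(1, 1, 6, 8)    # (batch, heads, key_seq_len, head_dim)

# Region mask: these 4 queries should favor the first 3 key positions.
region_mask = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0]).expand(1, 1, 4, 6)

scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor
m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight  # headroom up to the per-query max
m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0]  # headroom down to the per-query min

percent_through = 0.1  # early in denoising -> strong bias; near the end -> almost none
attn_mask_weight = 1.0 * ((1 - percent_through) ** 5)
bias = attn_mask_weight * (m_pos * region_mask - m_neg * (1.0 - region_mask))

biased = torch.softmax(attn_weight + bias, dim=-1)
print(biased[0, 0, 0])  # probability mass shifted toward the first 3 key positions
```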
@@ -0,0 +1,164 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
|
TextConditioningRegions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RegionalPromptData:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
regions: list[TextConditioningRegions],
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
max_downscale_factor: int = 8,
|
||||||
|
):
|
||||||
|
"""Initialize a `RegionalPromptData` object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
|
||||||
|
batch.
|
||||||
|
device (torch.device): The device to use for the attention masks.
|
||||||
|
dtype (torch.dtype): The data type to use for the attention masks.
|
||||||
|
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
|
||||||
|
in steps of 2x.
|
||||||
|
"""
|
||||||
|
self._regions = regions
|
||||||
|
self._device = device
|
||||||
|
self._dtype = dtype
|
||||||
|
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
|
||||||
|
# sequence length of s.
|
||||||
|
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
|
||||||
|
regions, max_downscale_factor
|
||||||
|
)
|
||||||
|
self._negative_cross_attn_mask_score = 0.0
|
||||||
|
self._size_weight = 1.0
|
||||||
|
|
||||||
|
def _prepare_spatial_masks(
|
||||||
|
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
||||||
|
) -> list[dict[int, torch.Tensor]]:
|
||||||
|
"""Prepare the spatial masks for all downscaling factors."""
|
||||||
|
# batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
|
||||||
|
# of s.
|
||||||
|
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
|
||||||
|
|
||||||
|
for batch_sample_regions in regions:
|
||||||
|
batch_sample_masks_by_seq_len.append({})
|
||||||
|
|
||||||
|
# Convert the bool masks to float masks so that max pooling can be applied.
|
||||||
|
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
|
||||||
|
|
||||||
|
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
|
||||||
|
downscale_factor = 1
|
||||||
|
while downscale_factor <= max_downscale_factor:
|
||||||
|
b, _num_prompts, h, w = batch_sample_masks.shape
|
||||||
|
assert b == 1
|
||||||
|
query_seq_len = h * w
|
||||||
|
|
||||||
|
batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks
|
||||||
|
|
||||||
|
downscale_factor *= 2
|
||||||
|
if downscale_factor <= max_downscale_factor:
|
||||||
|
# We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
|
||||||
|
# regions to be lost entirely.
|
||||||
|
# TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
|
||||||
|
# potentially use a weighted mask rather than a binary mask.
|
||||||
|
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2)
|
||||||
|
|
||||||
|
return batch_sample_masks_by_seq_len
|
||||||
|
|
||||||
|
def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
|
||||||
|
"""Get the cross-attention mask for the given query sequence length.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_seq_len: The length of the flattened spatial features at the current downscaling level.
|
||||||
|
key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention
|
||||||
|
layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The masks.
|
||||||
|
shape: (batch_size, query_seq_len, key_seq_len).
|
||||||
|
dtype: float
|
||||||
|
The mask is a binary mask with values of 0.0 and 1.0.
|
||||||
|
"""
|
||||||
|
batch_size = len(self._spatial_masks_by_seq_len)
|
||||||
|
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
|
||||||
|
|
||||||
|
# Create an empty attention mask with the correct shape.
|
||||||
|
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
|
||||||
|
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
|
||||||
|
batch_sample_regions = self._regions[batch_idx]
|
||||||
|
|
||||||
|
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
|
||||||
|
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
|
||||||
|
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
|
||||||
|
|
||||||
|
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
|
||||||
|
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :]
|
||||||
|
size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
|
||||||
|
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
|
||||||
|
# size = size.to(dtype=batch_sample_query_scores.dtype)
|
||||||
|
# batch_sample_query_mask = batch_sample_query_scores > 0.5
|
||||||
|
# batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
|
||||||
|
# batch_sample_query_scores[~batch_sample_query_mask] = 0.0
|
||||||
|
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores * (
|
||||||
|
mask_weight + self._size_weight * (1 - size)
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_mask
|
||||||
|
|
||||||
|
    def get_self_attn_mask(self, query_seq_len: int, percent_through: float) -> torch.Tensor:
        """Get the self-attention mask for the given query sequence length.

        Args:
            query_seq_len: The length of the flattened spatial features at the current downscaling level.

        Returns:
            torch.Tensor: The masks.
                shape: (batch_size, query_seq_len, query_seq_len).
                dtype: float
                The mask is a binary mask with values of 0.0 and 1.0.
        """
        batch_size = len(self._spatial_masks_by_seq_len)
        batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]

        # Create an empty attention mask with the correct shape.
        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=self._dtype, device=self._device)

        for batch_idx in range(batch_size):
            batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
            batch_sample_regions = self._regions[batch_idx]

            # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
            _, num_prompts, _, _ = batch_sample_spatial_masks.shape
            batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))

            for prompt_idx in range(num_prompts):
                prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0]  # Shape: (query_seq_len,)
                size = prompt_query_mask.sum() / prompt_query_mask.numel()
                size = size.to(dtype=prompt_query_mask.dtype)
                mask_weight = batch_sample_regions.mask_weights[prompt_idx]
                # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
                # query_seq_len) mask.
                # TODO(ryand): Is += really the best option here? Maybe elementwise max is better?
                attn_mask[batch_idx, :, :] = torch.maximum(
                    attn_mask[batch_idx, :, :],
                    prompt_query_mask.unsqueeze(0)
                    * prompt_query_mask.unsqueeze(1)
                    * (mask_weight + self._size_weight * (1 - size)),
                )

            # if attn_mask[batch_idx].max() < 0.01:
            #     attn_mask[batch_idx, ...] = 1.0

            # attn_mask[attn_mask > 0.5] = 1.0
            # attn_mask[attn_mask <= 0.5] = 0.0
            # attn_mask_min = attn_mask[batch_idx].min()

            # # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
            # if abs(attn_mask_min) > 0.0001:
            #     attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
        return attn_mask
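A minimal sketch of how a weighted mask of this shape could be folded into a standard attention call. This is illustrative only, not the attention processor used by this branch; the log-space conversion and all variable names here are assumptions.

import torch
import torch.nn.functional as F

def attention_with_regional_mask(q, k, v, regional_mask):
    # q, k, v: (batch, num_heads, seq_len, head_dim).
    # regional_mask: (batch, query_seq_len, key_seq_len) of non-negative weights, e.g. a tensor
    # like the ones returned by get_cross_attn_mask() / get_self_attn_mask() above (assumed input).
    # scaled_dot_product_attention adds a float attn_mask to the pre-softmax scores, so convert the
    # multiplicative weights to an additive bias and broadcast over the head dimension.
    attn_bias = torch.log(regional_mask.clamp(min=1e-6)).unsqueeze(1)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)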
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import math
+import time
 from contextlib import contextmanager
 from typing import Any, Callable, Optional, Union
 
@@ -10,11 +11,13 @@ from typing_extensions import TypeAlias
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    ConditioningData,
     ExtraConditioningInfo,
-    PostprocessingSettings,
-    SDXLConditioningInfo,
+    IPAdapterConditioningInfo,
+    Range,
+    TextConditioningData,
+    TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
 
 from .cross_attention_control import (
     CrossAttentionType,
@@ -56,7 +59,6 @@ class InvokeAIDiffuserComponent:
         :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
         """
         config = InvokeAIAppConfig.get_config()
-        self.conditioning = None
         self.model = model
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
@@ -91,7 +93,7 @@ class InvokeAIDiffuserComponent:
         timestep: torch.Tensor,
         step_index: int,
         total_step_count: int,
-        conditioning_data,
+        conditioning_data: TextConditioningData,
     ):
         down_block_res_samples, mid_block_res_sample = None, None
 
@@ -124,38 +126,30 @@ class InvokeAIDiffuserComponent:
             added_cond_kwargs = None
 
             if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-                if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+                if conditioning_data.is_sdxl():
                     added_cond_kwargs = {
-                        "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                        "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                        "text_embeds": conditioning_data.cond_text.pooled_embeds,
+                        "time_ids": conditioning_data.cond_text.add_time_ids,
                     }
-                encoder_hidden_states = conditioning_data.text_embeddings.embeds
+                encoder_hidden_states = conditioning_data.cond_text.embeds
                 encoder_attention_mask = None
             else:
-                if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+                if conditioning_data.is_sdxl():
                     added_cond_kwargs = {
                         "text_embeds": torch.cat(
                             [
-                                # TODO: how to pad? just by zeros? or even truncate?
-                                conditioning_data.unconditioned_embeddings.pooled_embeds,
-                                conditioning_data.text_embeddings.pooled_embeds,
+                                conditioning_data.uncond_text.pooled_embeds,
+                                conditioning_data.cond_text.pooled_embeds,
                             ],
                             dim=0,
                         ),
                         "time_ids": torch.cat(
-                            [
-                                conditioning_data.unconditioned_embeddings.add_time_ids,
-                                conditioning_data.text_embeddings.add_time_ids,
-                            ],
+                            [conditioning_data.uncond_text.add_time_ids, conditioning_data.cond_text.add_time_ids],
                             dim=0,
                         ),
                     }
-                (
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                ) = self._concat_conditionings_for_batch(
-                    conditioning_data.unconditioned_embeddings.embeds,
-                    conditioning_data.text_embeddings.embeds,
+                (encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
+                    conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
                 )
             if isinstance(control_datum.weight, list):
                 # if controlnet has multiple weights, use the weight for the current step
@@ -199,16 +193,17 @@ class InvokeAIDiffuserComponent:
         self,
         sample: torch.Tensor,
         timestep: torch.Tensor,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
+        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
         step_index: int,
        total_step_count: int,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
+        percent_through = step_index / total_step_count
         cross_attention_control_types_to_do = []
         if self.cross_attention_control_context is not None:
-            percent_through = step_index / total_step_count
             cross_attention_control_types_to_do = (
                 self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
             )
@@ -224,6 +219,8 @@ class InvokeAIDiffuserComponent:
                 x=sample,
                 sigma=timestep,
                 conditioning_data=conditioning_data,
+                ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                 cross_attention_control_types_to_do=cross_attention_control_types_to_do,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
@@ -237,6 +234,8 @@ class InvokeAIDiffuserComponent:
                 x=sample,
                 sigma=timestep,
                 conditioning_data=conditioning_data,
+                percent_through=percent_through,
+                ip_adapter_conditioning=ip_adapter_conditioning,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -244,19 +243,6 @@ class InvokeAIDiffuserComponent:
 
         return unconditioned_next_x, conditioned_next_x
 
-    def do_latent_postprocessing(
-        self,
-        postprocessing_settings: PostprocessingSettings,
-        latents: torch.Tensor,
-        sigma,
-        step_index,
-        total_step_count,
-    ) -> torch.Tensor:
-        if postprocessing_settings is not None:
-            percent_through = step_index / total_step_count
-            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
-        return latents
-
     def _concat_conditionings_for_batch(self, unconditioning, conditioning):
         def _pad_conditioning(cond, target_len, encoder_attention_mask):
             conditioning_attention_mask = torch.ones(
@@ -304,13 +290,13 @@ class InvokeAIDiffuserComponent:
 
         return torch.cat([unconditioning, conditioning]), encoder_attention_mask
 
-    # methods below are called from do_diffusion_step and should be considered private to this class.
-
     def _apply_standard_conditioning(
         self,
         x,
         sigma,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
+        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -321,41 +307,55 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
-        cross_attention_kwargs = None
-        if conditioning_data.ip_adapter_conditioning is not None:
+        cross_attention_kwargs = {}
+        if ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.stack(
-                        [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
-                    )
-                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
+                for ipa_conditioning in ip_adapter_conditioning
+            ]
+
+        uncond_text = conditioning_data.uncond_text
+        cond_text = conditioning_data.cond_text
 
         added_cond_kwargs = None
-        if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+        if conditioning_data.is_sdxl():
             added_cond_kwargs = {
-                "text_embeds": torch.cat(
-                    [
-                        # TODO: how to pad? just by zeros? or even truncate?
-                        conditioning_data.unconditioned_embeddings.pooled_embeds,
-                        conditioning_data.text_embeddings.pooled_embeds,
-                    ],
-                    dim=0,
-                ),
-                "time_ids": torch.cat(
-                    [
-                        conditioning_data.unconditioned_embeddings.add_time_ids,
-                        conditioning_data.text_embeddings.add_time_ids,
-                    ],
-                    dim=0,
-                ),
+                "text_embeds": torch.cat([uncond_text.pooled_embeds, cond_text.pooled_embeds], dim=0),
+                "time_ids": torch.cat([uncond_text.add_time_ids, cond_text.add_time_ids], dim=0),
             }
 
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-            conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
+            uncond_text.embeds, cond_text.embeds
         )
+
+        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
+            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
+            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
+            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
+            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
+            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
+            regions = []
+            for c, r in [
+                (conditioning_data.uncond_text, conditioning_data.uncond_regions),
+                (conditioning_data.cond_text, conditioning_data.cond_regions),
+            ]:
+                if r is None:
+                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
+                    _, _, h, w = x.shape
+                    r = TextConditioningRegions(
+                        masks=torch.ones((1, 1, h, w), dtype=torch.bool),
+                        ranges=[Range(start=0, end=c.embeds.shape[1])],
+                        mask_weights=[0.0],
+                    )
+                regions.append(r)
+
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=regions, device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+            time.sleep(1.0)
+
         both_results = self.model_forward_callback(
             x_twice,
             sigma_twice,
@@ -374,8 +374,10 @@ class InvokeAIDiffuserComponent:
         self,
         x: torch.Tensor,
         sigma,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
+        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
         cross_attention_control_types_to_do: list[CrossAttentionType],
+        percent_through: float,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -422,36 +424,40 @@ class InvokeAIDiffuserComponent:
         # Unconditioned pass
         #####################
 
-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}
 
         # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
-        if conditioning_data.ip_adapter_conditioning is not None:
+        if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]
 
         # Prepare cross-attention control kwargs for the unconditioned pass.
         if cross_attn_processor_context is not None:
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
 
         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
-        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
-                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.uncond_text.pooled_embeds,
+                "time_ids": conditioning_data.uncond_text.add_time_ids,
            }
 
+        # Prepare prompt regions for the unconditioned pass.
+        if conditioning_data.uncond_regions is not None:
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.unconditioned_embeddings.embeds,
+            conditioning_data.uncond_text.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
@@ -463,36 +469,41 @@ class InvokeAIDiffuserComponent:
         # Conditioned pass
         ###################
 
-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}
 
         # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
-        if conditioning_data.ip_adapter_conditioning is not None:
+        if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]
 
         # Prepare cross-attention control kwargs for the conditioned pass.
         if cross_attn_processor_context is not None:
             cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
 
         # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
-                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.cond_text.pooled_embeds,
+                "time_ids": conditioning_data.cond_text.add_time_ids,
            }
 
+        # Prepare prompt regions for the conditioned pass.
+        if conditioning_data.cond_regions is not None:
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.text_embeddings.embeds,
+            conditioning_data.cond_text.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=cond_down_block,
             mid_block_additional_residual=cond_mid_block,
@@ -506,64 +517,3 @@ class InvokeAIDiffuserComponent:
         scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
         combined_next_x = unconditioned_next_x + scaled_delta
         return combined_next_x
-
-    def apply_symmetry(
-        self,
-        postprocessing_settings: PostprocessingSettings,
-        latents: torch.Tensor,
-        percent_through: float,
-    ) -> torch.Tensor:
-        # Reset our last percent through if this is our first step.
-        if percent_through == 0.0:
-            self.last_percent_through = 0.0
-
-        if postprocessing_settings is None:
-            return latents
-
-        # Check for out of bounds
-        h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
-        if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
-            h_symmetry_time_pct = None
-
-        v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
-        if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
-            v_symmetry_time_pct = None
-
-        dev = latents.device.type
-
-        latents.to(device="cpu")
-
-        if (
-            h_symmetry_time_pct is not None
-            and self.last_percent_through < h_symmetry_time_pct
-            and percent_through >= h_symmetry_time_pct
-        ):
-            # Horizontal symmetry occurs on the 3rd dimension of the latent
-            width = latents.shape[3]
-            x_flipped = torch.flip(latents, dims=[3])
-            latents = torch.cat(
-                [
-                    latents[:, :, :, 0 : int(width / 2)],
-                    x_flipped[:, :, :, int(width / 2) : int(width)],
-                ],
-                dim=3,
-            )
-
-        if (
-            v_symmetry_time_pct is not None
-            and self.last_percent_through < v_symmetry_time_pct
-            and percent_through >= v_symmetry_time_pct
-        ):
-            # Vertical symmetry occurs on the 2nd dimension of the latent
-            height = latents.shape[2]
-            y_flipped = torch.flip(latents, dims=[2])
-            latents = torch.cat(
-                [
-                    latents[:, :, 0 : int(height / 2)],
-                    y_flipped[:, :, int(height / 2) : int(height)],
-                ],
-                dim=2,
-            )
-
-        self.last_percent_through = percent_through
-        return latents.to(device=dev)
@@ -1,52 +1,55 @@
 from contextlib import contextmanager
+from typing import Optional
 
 from diffusers.models import UNet2DConditionModel
 
-from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor2_0
 
 
-class UNetPatcher:
-    """A class that contains multiple IP-Adapters and can apply them to a UNet."""
+class UNetAttentionPatcher:
+    """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
 
-    def __init__(self, ip_adapters: list[IPAdapter]):
+    def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
         self._ip_adapters = ip_adapters
-        self._scales = [1.0] * len(self._ip_adapters)
+        self._ip_adapter_scales = None
+
+        if self._ip_adapters is not None:
+            self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
 
     def set_scale(self, idx: int, value: float):
-        self._scales[idx] = value
+        self._ip_adapter_scales[idx] = value
 
     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
-        weights into them.
+        weights into them (if IP-Adapters are being applied).
 
         Note that the `unet` param is only used to determine attention block dimensions and naming.
         """
         # Construct a dict of attention processors based on the UNet's architecture.
         attn_procs = {}
         for idx, name in enumerate(unet.attn_processors.keys()):
-            if name.endswith("attn1.processor"):
-                attn_procs[name] = AttnProcessor2_0()
+            if name.endswith("attn1.processor") or self._ip_adapters is None:
+                # "attn1" processors do not use IP-Adapters.
+                attn_procs[name] = CustomAttnProcessor2_0()
             else:
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
-                attn_procs[name] = IPAttnProcessor2_0(
+                attn_procs[name] = CustomAttnProcessor2_0(
                     [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
-                    self._scales,
+                    self._ip_adapter_scales,
                 )
         return attn_procs
 
     @contextmanager
     def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
-        """A context manager that patches `unet` with IP-Adapter attention processors."""
+        """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
 
         attn_procs = self._prepare_attention_processors(unet)
 
         orig_attn_processors = unet.attn_processors
 
         try:
-            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
-            # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
-            # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
+            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
+            # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
+            # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
             unet.set_attn_processor(attn_procs)
             yield None
         finally:
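A rough usage sketch for the patcher above, assuming `unet` is a loaded diffusers UNet2DConditionModel and that `latents`, `timestep`, and `text_embeds` are prepared elsewhere; a list of IPAdapter instances would be passed instead of None when IP-Adapter weights are wanted.

patcher = UNetAttentionPatcher(ip_adapters=None)
with patcher.apply_ip_adapter_attention(unet):
    # While the context is active, every attention layer runs CustomAttnProcessor2_0, so
    # cross_attention_kwargs such as "regional_prompt_data" are understood by the UNet.
    noise_pred = unet(latents, timestep, encoder_hidden_states=text_embeds).sample
# On exit, the original attention processors are restored.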
@@ -858,9 +858,9 @@ def do_textual_inversion_training(
                 # Let's make sure we don't update any embedding weights besides the newly added token
                 index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
-                        index_no_updates
-                    ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+                        orig_embeds_params[index_no_updates]
+                    )
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
@@ -144,7 +144,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
 
         self.nextrely = top_of_table
         self.lora_models = self.add_model_widgets(
-            model_type=ModelType.LoRA,
+            model_type=ModelType.Lora,
             window_width=window_width,
         )
         bottom_of_table = max(bottom_of_table, self.nextrely)
@@ -30,7 +30,7 @@
     "lint:prettier": "prettier --check .",
     "lint:tsc": "tsc --noEmit",
     "lint": "concurrently -g -c red,green,yellow,blue,magenta pnpm:lint:*",
-    "fix": "eslint --fix . && prettier --log-level warn --write .",
+    "fix": "knip --fix && eslint --fix . && prettier --log-level warn --write .",
     "preinstall": "npx only-allow pnpm",
     "storybook": "storybook dev -p 6006",
     "build-storybook": "storybook build",
@@ -304,12 +304,6 @@
       "method": "High Resolution Fix Method"
     }
   },
-  "prompt": {
-    "addPromptTrigger": "Add Prompt Trigger",
-    "compatibleEmbeddings": "Compatible Embeddings",
-    "noPromptTriggers": "No triggers available",
-    "noMatchingTriggers": "No matching triggers"
-  },
   "embedding": {
     "addEmbedding": "Add Embedding",
     "incompatibleModel": "Incompatible base model:",
@@ -153,7 +153,7 @@ addFirstListImagesListener(startAppListening);
 // Ad-hoc upscale workflwo
 addUpscaleRequestedListener(startAppListening);
 
-// Prompts
+// Dynamic prompts
 addDynamicPromptsListener(startAppListening);
 
 addSetDefaultSettingsListener(startAppListening);
@@ -7,10 +7,8 @@ import {
   selectAllT2IAdapters,
 } from 'features/controlAdapters/store/controlAdaptersSlice';
 import { loraRemoved } from 'features/lora/store/loraSlice';
-import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
-import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features/parameters/store/generationSlice';
+import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
 import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
-import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
 import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
 import { forEach, some } from 'lodash-es';
 import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
@@ -26,9 +24,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
       const log = logger('models');
       log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
 
-      const state = getState();
-
-      const currentModel = state.generation.model;
+      const currentModel = getState().generation.model;
       const models = mainModelsAdapterSelectors.selectAll(action.payload);
 
       if (models.length === 0) {
@@ -43,29 +39,6 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
         return;
       }
 
-      const defaultModel = state.config.sd.defaultModel;
-      const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
-
-      if (defaultModelInList) {
-        const result = zParameterModel.safeParse(defaultModelInList);
-        if (result.success) {
-          dispatch(modelChanged(defaultModelInList, currentModel));
-
-          const optimalDimension = getOptimalDimension(defaultModelInList);
-          if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
-            return;
-          }
-          const { width, height } = calculateNewSize(
-            state.generation.aspectRatio.value,
-            optimalDimension * optimalDimension
-          );
-
-          dispatch(widthChanged(width));
-          dispatch(heightChanged(height));
-          return;
-        }
-      }
-
       const result = zParameterModel.safeParse(models[0]);
 
       if (!result.success) {
@@ -34,13 +34,13 @@ export const addSetDefaultSettingsListener = (startAppListeni
         return;
       }
 
-      const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
+      const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap();
 
-      if (!modelConfig || !modelConfig.default_settings) {
+      if (!metadata || !metadata.default_settings) {
         return;
       }
 
-      const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
+      const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings;
 
       if (vae) {
         // we store this as "default" within default settings
@@ -14,7 +14,7 @@ export const addModelInstallEventListener = (startAppListenin
       const { bytes, total_bytes, id } = action.payload.data;
 
       dispatch(
-        modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+        modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
           const modelImport = draft.find((m) => m.id === id);
           if (modelImport) {
             modelImport.bytes = bytes;
@@ -33,7 +33,7 @@ export const addModelInstallEventListener = (startAppListenin
       const { id } = action.payload.data;
 
       dispatch(
-        modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+        modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
           const modelImport = draft.find((m) => m.id === id);
           if (modelImport) {
             modelImport.status = 'completed';
@@ -41,7 +41,7 @@ export const addModelInstallEventListener = (startAppListenin
           return draft;
         })
       );
-      dispatch(api.util.invalidateTags(['Model']));
+      dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }]));
     },
   });
 
@@ -51,7 +51,7 @@ export const addModelInstallEventListener = (startAppListenin
      const { id, error, error_type } = action.payload.data;
 
      dispatch(
-        modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+        modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
          const modelImport = draft.find((m) => m.id === id);
          if (modelImport) {
            modelImport.status = 'error';
@@ -8,15 +8,15 @@ type Props = {
   onOpen: () => void;
 };
 
-export const AddPromptTriggerButton = memo((props: Props) => {
+export const AddEmbeddingButton = memo((props: Props) => {
   const { onOpen, isOpen } = props;
   const { t } = useTranslation();
   return (
-    <Tooltip label={t('prompt.addPromptTrigger')}>
+    <Tooltip label={t('embedding.addEmbedding')}>
       <IconButton
         variant="promptOverlay"
         isDisabled={isOpen}
-        aria-label={t('prompt.addPromptTrigger')}
+        aria-label={t('embedding.addEmbedding')}
         icon={<PiCodeBold />}
         onClick={onOpen}
       />
@@ -24,4 +24,4 @@ export const AddPromptTriggerButton = memo((props: Props) => {
   );
 });
 
-AddPromptTriggerButton.displayName = 'AddPromptTriggerButton';
+AddEmbeddingButton.displayName = 'AddEmbeddingButton';
@@ -1,9 +1,9 @@
 import { Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
-import { PromptTriggerSelect } from 'features/prompt/PromptTriggerSelect';
-import type { PromptPopoverProps } from 'features/prompt/types';
+import { EmbeddingSelect } from 'features/embedding/EmbeddingSelect';
+import type { EmbeddingPopoverProps } from 'features/embedding/types';
 import { memo } from 'react';
 
-export const PromptPopover = memo((props: PromptPopoverProps) => {
+export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
   const { onSelect, isOpen, onClose, width, children } = props;
 
   return (
@@ -14,7 +14,7 @@ export const PromptPopover = memo((props: PromptPopoverProps) => {
       openDelay={0}
       closeDelay={0}
       closeOnBlur={true}
-      returnFocusOnClose={false}
+      returnFocusOnClose={true}
       isLazy
     >
       <PopoverAnchor>{children}</PopoverAnchor>
@@ -27,11 +27,11 @@ export const PromptPopover = memo((props: PromptPopoverProps) => {
         borderStyle="solid"
       >
         <PopoverBody p={0} width={`calc(${width}px - 0.25rem)`}>
-          <PromptTriggerSelect onClose={onClose} onSelect={onSelect} />
+          <EmbeddingSelect onClose={onClose} onSelect={onSelect} />
         </PopoverBody>
       </PopoverContent>
     </Popover>
   );
 });
 
-PromptPopover.displayName = 'PromptPopover';
+EmbeddingPopover.displayName = 'EmbeddingPopover';
@@ -0,0 +1,21 @@
+import type { Meta, StoryObj } from '@storybook/react';
+
+import { EmbeddingSelect } from './EmbeddingSelect';
+import type { EmbeddingSelectProps } from './types';
+
+const meta: Meta<typeof EmbeddingSelect> = {
+  title: 'Feature/Prompt/EmbeddingSelect',
+  tags: ['autodocs'],
+  component: EmbeddingSelect,
+};
+
+export default meta;
+type Story = StoryObj<typeof EmbeddingSelect>;
+
+const Component = (props: EmbeddingSelectProps) => {
+  return <EmbeddingSelect {...props}>Invoke</EmbeddingSelect>;
+};
+
+export const Default: Story = {
+  render: Component,
+};
@@ -0,0 +1,67 @@
+import type { ChakraProps } from '@invoke-ai/ui-library';
+import { Combobox, FormControl } from '@invoke-ai/ui-library';
+import { useAppSelector } from 'app/store/storeHooks';
+import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
+import type { EmbeddingSelectProps } from 'features/embedding/types';
+import { t } from 'i18next';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
+import type { TextualInversionModelConfig } from 'services/api/types';
+
+const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
+
+export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
+  const { t } = useTranslation();
+
+  const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
+
+  const getIsDisabled = useCallback(
+    (embedding: TextualInversionModelConfig): boolean => {
+      const isCompatible = currentBaseModel === embedding.base;
+      const hasMainModel = Boolean(currentBaseModel);
+      return !hasMainModel || !isCompatible;
+    },
+    [currentBaseModel]
+  );
+  const { data, isLoading } = useGetTextualInversionModelsQuery();
+
+  const _onChange = useCallback(
+    (embedding: TextualInversionModelConfig | null) => {
+      if (!embedding) {
+        return;
+      }
+      onSelect(embedding.name);
+    },
+    [onSelect]
+  );
+
+  const { options, onChange } = useGroupedModelCombobox({
+    modelEntities: data,
+    getIsDisabled,
+    onChange: _onChange,
+  });
+
+  return (
+    <FormControl>
+      <Combobox
+        placeholder={isLoading ? t('common.loading') : t('embedding.addEmbedding')}
+        defaultMenuIsOpen
+        autoFocus
+        value={null}
+        options={options}
+        noOptionsMessage={noOptionsMessage}
+        onChange={onChange}
+        onMenuClose={onClose}
+        data-testid="add-embedding"
+        sx={selectStyles}
+      />
+    </FormControl>
+  );
+});
+
+EmbeddingSelect.displayName = 'EmbeddingSelect';
+
+const selectStyles: ChakraProps['sx'] = {
+  w: 'full',
+};
@@ -1,12 +1,12 @@
 import type { PropsWithChildren } from 'react';
 
-export type PromptTriggerSelectProps = {
+export type EmbeddingSelectProps = {
   onSelect: (v: string) => void;
   onClose: () => void;
 };
 
-export type PromptPopoverProps = PropsWithChildren &
-  PromptTriggerSelectProps & {
+export type EmbeddingPopoverProps = PropsWithChildren &
+  EmbeddingSelectProps & {
     isOpen: boolean;
     width?: number | string;
   };
@@ -4,13 +4,13 @@ import type { ChangeEventHandler, KeyboardEventHandler, RefObject } from 'react'
 import { useCallback } from 'react';
 import { flushSync } from 'react-dom';
 
-type UseInsertTriggerArg = {
+type UseInsertEmbeddingArg = {
   prompt: string;
   textareaRef: RefObject<HTMLTextAreaElement>;
   onChange: (v: string) => void;
 };
 
-export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertTriggerArg) => {
+export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertEmbeddingArg) => {
   const { isOpen, onClose, onOpen } = useDisclosure();
 
   const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback(
@@ -20,13 +20,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
     [_onChange]
   );
 
-  const insertTrigger = useCallback(
+  const insertEmbedding = useCallback(
     (v: string) => {
       if (!textareaRef.current) {
         return;
       }
 
-      // this is where we insert the trigger
+      // this is where we insert the TI trigger
       const caret = textareaRef.current.selectionStart;
 
       if (isNil(caret)) {
@@ -35,9 +35,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
 
       let newPrompt = prompt.slice(0, caret);
 
-      newPrompt += `${v}`;
+      if (newPrompt[newPrompt.length - 1] !== '<') {
+        newPrompt += '<';
+      }
+
+      newPrompt += `${v}>`;
 
-      // we insert the cursor after the end of trigger
+      // we insert the cursor after the `>`
       const finalCaretPos = newPrompt.length;
 
       newPrompt += prompt.slice(caret);
@@ -47,7 +51,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
         _onChange(newPrompt);
       });
 
-      // set the cursor position to just after the trigger
+      // set the caret position to just after the TI trigger
       textareaRef.current.selectionStart = finalCaretPos;
       textareaRef.current.selectionEnd = finalCaretPos;
     },
@@ -58,17 +62,17 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
     textareaRef.current?.focus();
   }, [textareaRef]);
 
-  const handleClosePopover = useCallback(() => {
+  const handleClose = useCallback(() => {
     onClose();
     onFocus();
   }, [onFocus, onClose]);
 
-  const onSelect = useCallback(
+  const onSelectEmbedding = useCallback(
     (v: string) => {
-      insertTrigger(v);
-      handleClosePopover();
+      insertEmbedding(v);
+      handleClose();
     },
-    [handleClosePopover, insertTrigger]
+    [handleClose, insertEmbedding]
   );
 
   const onKeyDown: KeyboardEventHandler<HTMLTextAreaElement> = useCallback(
@@ -86,7 +90,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
     isOpen,
     onClose,
     onOpen,
-    onSelect,
+    onSelectEmbedding,
     onKeyDown,
     onFocus,
   };
@@ -0,0 +1,228 @@
+import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
+import { useAppDispatch } from 'app/store/storeHooks';
+import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
+import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
+import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
+import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
+import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
+import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
+import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
+import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
+import { addToast } from 'features/system/store/systemSlice';
+import { makeToast } from 'features/system/util/makeToast';
+import { isNil, omitBy } from 'lodash-es';
+import { useCallback, useEffect } from 'react';
+import type { SubmitHandler } from 'react-hook-form';
+import { useForm } from 'react-hook-form';
+import { useTranslation } from 'react-i18next';
+import { useInstallModelMutation } from 'services/api/endpoints/models';
+import type { AnyModelConfig } from 'services/api/types';
+
+export const AdvancedImport = () => {
+  const dispatch = useAppDispatch();
+
+  const [installModel] = useInstallModelMutation();
+
+  const { t } = useTranslation();
+
+  const {
+    register,
+    handleSubmit,
+    control,
+    formState: { errors },
+    setValue,
+    resetField,
+    reset,
+    watch,
+  } = useForm<AnyModelConfig>({
+    defaultValues: {
+      name: '',
+      base: 'sd-1',
+      type: 'main',
+      path: '',
+      description: '',
+      format: 'diffusers',
+      vae: '',
+      variant: 'normal',
+    },
+    mode: 'onChange',
+  });
+
+  const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
+    (values) => {
+      installModel({
+        source: values.path,
+        config: omitBy(values, isNil),
+      })
+        .unwrap()
+        .then((_) => {
+          dispatch(
+            addToast(
+              makeToast({
+                title: t('modelManager.modelAdded', {
+                  modelName: values.name,
+                }),
+                status: 'success',
+              })
+            )
+          );
+          reset();
+        })
+        .catch((error) => {
+          if (error) {
+            dispatch(
+              addToast(
+                makeToast({
+                  title: t('toast.modelAddFailed'),
+                  status: 'error',
+                })
+              )
+            );
+          }
+        });
+    },
+    [installModel, dispatch, t, reset]
+  );
+
+  const watchedModelType = watch('type');
+  const watchedModelFormat = watch('format');
+
+  useEffect(() => {
+    if (watchedModelType === 'main') {
+      setValue('format', 'diffusers');
+      setValue('repo_variant', '');
+      setValue('variant', 'normal');
+    }
+    if (watchedModelType === 'lora') {
+      setValue('format', 'lycoris');
+    } else if (watchedModelType === 'embedding') {
+      setValue('format', 'embedding_file');
+    } else if (watchedModelType === 'ip_adapter') {
+      setValue('format', 'invokeai');
+    } else {
+      setValue('format', 'diffusers');
+    }
+    resetField('upcast_attention');
+    resetField('ztsnr_training');
+    resetField('vae');
+    resetField('config');
+    resetField('prediction_type');
+    resetField('image_encoder_model_id');
+  }, [watchedModelType, resetField, setValue]);
+
+  return (
+    <ScrollableContent>
+      <form onSubmit={handleSubmit(onSubmit)}>
+        <Flex flexDirection="column" gap={4} width="100%" pb={10}>
+          <Flex alignItems="flex-end" gap="4">
+            <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+              <FormLabel>{t('modelManager.modelType')}</FormLabel>
+              <ModelTypeSelect<AnyModelConfig> control={control} name="type" />
+            </FormControl>
+            <Text px="2" fontSize="xs" textAlign="center">
+              {t('modelManager.advancedImportInfo')}
+            </Text>
+          </Flex>
+
+          <Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
+            <FormControl isInvalid={Boolean(errors.name)}>
+              <Flex direction="column" width="full">
+                <FormLabel>{t('modelManager.name')}</FormLabel>
+                <Input
+                  {...register('name', {
+                    validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
+                  })}
+                />
+                {errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
+              </Flex>
+            </FormControl>
+            <Flex>
+              <FormControl>
+                <Flex direction="column" width="full">
+                  <FormLabel>{t('modelManager.description')}</FormLabel>
+                  <Textarea size="sm" {...register('description')} />
+                  {errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
+                </Flex>
+              </FormControl>
+            </Flex>
+            <Flex gap={4}>
+              <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                <FormLabel>{t('modelManager.baseModel')}</FormLabel>
+                <BaseModelSelect control={control} name="base" />
+              </FormControl>
+              <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                <FormLabel>{t('common.format')}</FormLabel>
+                <ModelFormatSelect control={control} name="format" />
+              </FormControl>
+            </Flex>
+            <Flex gap={4}>
+              <FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
+                <FormLabel>{t('modelManager.path')}</FormLabel>
+                <Input
+                  {...register('path', {
+                    validate: (value) => value.trim().length > 0 || 'Must provide a path',
+                  })}
+                />
+                {errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
+              </FormControl>
+            </Flex>
+            {watchedModelType === 'main' && (
+              <>
+                <Flex gap={4}>
+                  {watchedModelFormat === 'diffusers' && (
+                    <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                      <FormLabel>{t('modelManager.repoVariant')}</FormLabel>
+                      <RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
+                    </FormControl>
+                  )}
+                  {watchedModelFormat === 'checkpoint' && (
+                    <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                      <FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
+                      <Input {...register('config')} />
+                    </FormControl>
+                  )}
+
+                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                    <FormLabel>{t('modelManager.variant')}</FormLabel>
+                    <ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
+                  </FormControl>
+                </Flex>
+                <Flex gap={4}>
+                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                    <FormLabel>{t('modelManager.predictionType')}</FormLabel>
+                    <PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
+                  </FormControl>
+                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                    <FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
+                    <BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
+                  </FormControl>
+                </Flex>
+                <Flex gap={4}>
+                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                    <FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
+                    <BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
+                  </FormControl>
+                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                    <FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
+                    <Input {...register('vae')} />
+                  </FormControl>
+                </Flex>
+              </>
+            )}
+            {watchedModelType === 'ip_adapter' && (
+              <Flex gap={4}>
+                <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                  <FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
+                  <Input {...register('image_encoder_model_id')} />
+                </FormControl>
+              </Flex>
+            )}
+            <Button mt={2} type="submit">
+              {t('modelManager.addModel')}
+            </Button>
+          </Flex>
+        </Flex>
+      </form>
+    </ScrollableContent>
+  );
+};
@@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
 import { makeToast } from 'features/system/util/makeToast';
 import { t } from 'i18next';
 import { useCallback, useMemo } from 'react';
-import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
+import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';

-import { ModelInstallQueueItem } from './ModelInstallQueueItem';
+import { ImportQueueItem } from './ImportQueueItem';

-export const ModelInstallQueue = () => {
+export const ImportQueue = () => {
 const dispatch = useAppDispatch();

-const { data } = useListModelInstallsQuery();
+const { data } = useGetModelImportsQuery();

-const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
+const [pruneModelImports] = usePruneModelImportsMutation();

-const pruneCompletedModelInstalls = useCallback(() => {
-_pruneCompletedModelInstalls()
+const pruneQueue = useCallback(() => {
+pruneModelImports()
 .unwrap()
 .then((_) => {
 dispatch(
@@ -41,7 +41,7 @@ export const ModelInstallQueue = () => {
 );
 }
 });
-}, [_pruneCompletedModelInstalls, dispatch]);
+}, [pruneModelImports, dispatch]);

 const pruneAvailable = useMemo(() => {
 return data?.some(
@@ -53,19 +53,14 @@ export const ModelInstallQueue = () => {
 <Flex flexDir="column" p={3} h="full">
 <Flex justifyContent="space-between" alignItems="center">
 <Text>{t('modelManager.importQueue')}</Text>
-<Button
-size="sm"
-isDisabled={!pruneAvailable}
-onClick={pruneCompletedModelInstalls}
-tooltip={t('modelManager.pruneTooltip')}
->
+<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
 {t('modelManager.prune')}
 </Button>
 </Flex>
 <Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
 <ScrollableContent>
 <Flex flexDir="column-reverse" gap="2">
-{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
+{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
 </Flex>
 </ScrollableContent>
 </Box>
@@ -6,24 +6,17 @@ import type { ModelInstallStatus } from 'services/api/types';
 const STATUSES = {
 waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
 downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
-downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
 running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
 completed: { colorScheme: 'green', translationKey: 'queue.completed' },
 error: { colorScheme: 'red', translationKey: 'queue.failed' },
 cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
 };

-const ModelInstallQueueBadge = ({
-status,
-errorReason,
-}: {
-status?: ModelInstallStatus;
-errorReason?: string | null;
-}) => {
+const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
 const { t } = useTranslation();

 if (!status || !Object.keys(STATUSES).includes(status)) {
-return null;
+return <></>;
 }

 return (
@@ -32,4 +25,4 @@ const ModelInstallQueueBadge = ({
 </Tooltip>
 );
 };
-export default memo(ModelInstallQueueBadge);
+export default memo(ImportQueueBadge);
@@ -3,16 +3,15 @@ import { useAppDispatch } from 'app/store/storeHooks';
 import { addToast } from 'features/system/store/systemSlice';
 import { makeToast } from 'features/system/util/makeToast';
 import { t } from 'i18next';
-import { isNil } from 'lodash-es';
 import { useCallback, useMemo } from 'react';
 import { PiXBold } from 'react-icons/pi';
-import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
+import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
 import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';

-import ModelInstallQueueBadge from './ModelInstallQueueBadge';
+import ImportQueueBadge from './ImportQueueBadge';

 type ModelListItemProps = {
-installJob: ModelInstallJob;
+model: ModelInstallJob;
 };

 const formatBytes = (bytes: number) => {
@@ -27,26 +26,26 @@ const formatBytes = (bytes: number) => {
 return `${bytes.toFixed(2)} ${units[i]}`;
 };

-export const ModelInstallQueueItem = (props: ModelListItemProps) => {
-const { installJob } = props;
+export const ImportQueueItem = (props: ModelListItemProps) => {
+const { model } = props;
 const dispatch = useAppDispatch();

-const [deleteImportModel] = useCancelModelInstallMutation();
+const [deleteImportModel] = useDeleteModelImportMutation();

 const source = useMemo(() => {
-if (installJob.source.type === 'hf') {
-return installJob.source as HFModelSource;
-} else if (installJob.source.type === 'local') {
-return installJob.source as LocalModelSource;
-} else if (installJob.source.type === 'url') {
-return installJob.source as URLModelSource;
+if (model.source.type === 'hf') {
+return model.source as HFModelSource;
+} else if (model.source.type === 'local') {
+return model.source as LocalModelSource;
+} else if (model.source.type === 'url') {
+return model.source as URLModelSource;
 } else {
-return installJob.source as LocalModelSource;
+return model.source as LocalModelSource;
 }
-}, [installJob.source]);
+}, [model.source]);

 const handleDeleteModelImport = useCallback(() => {
-deleteImportModel(installJob.id)
+deleteImportModel(model.id)
 .unwrap()
 .then((_) => {
 dispatch(
@@ -70,7 +69,7 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
 );
 }
 });
-}, [deleteImportModel, installJob, dispatch]);
+}, [deleteImportModel, model, dispatch]);

 const modelName = useMemo(() => {
 switch (source.type) {
@@ -86,23 +85,19 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
 }, [source]);

 const progressValue = useMemo(() => {
-if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
-return null;
-}
-
-if (installJob.total_bytes === 0) {
+if (model.bytes === undefined || model.total_bytes === undefined) {
 return 0;
 }

-return (installJob.bytes / installJob.total_bytes) * 100;
-}, [installJob.bytes, installJob.total_bytes]);
+return (model.bytes / model.total_bytes) * 100;
+}, [model.bytes, model.total_bytes]);

 const progressString = useMemo(() => {
-if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
+if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
 return '';
 }
-return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
-}, [installJob.bytes, installJob.total_bytes, installJob.status]);
+return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
+}, [model.bytes, model.total_bytes, model.status]);

 return (
 <Flex gap="2" w="full" alignItems="center">
@@ -114,21 +109,19 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
 <Flex flexDir="column" flex={1}>
 <Tooltip label={progressString}>
 <Progress
-value={progressValue ?? 0}
-isIndeterminate={progressValue === null}
+value={progressValue}
+isIndeterminate={progressValue === undefined}
 aria-label={t('accessibility.invokeProgressBar')}
 h={2}
 />
 </Tooltip>
 </Flex>
 <Box minW="100px" textAlign="center">
-<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
+<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
 </Box>

 <Box minW="20px">
-{(installJob.status === 'downloading' ||
-installJob.status === 'waiting' ||
-installJob.status === 'running') && (
+{(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && (
 <IconButton
 isRound={true}
 size="xs"
@@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
 import type { ChangeEventHandler } from 'react';
 import { useCallback, useState } from 'react';
 import { useTranslation } from 'react-i18next';
-import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
+import { useLazyScanModelsQuery } from 'services/api/endpoints/models';

-import { ScanModelsResults } from './ScanFolderResults';
+import { ScanModelsResults } from './ScanModelsResults';

 export const ScanModelsForm = () => {
 const [scanPath, setScanPath] = useState('');
 const [errorMessage, setErrorMessage] = useState('');
 const { t } = useTranslation();

-const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
+const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();

-const scanFolder = useCallback(async () => {
-_scanFolder({ scan_path: scanPath }).catch((error) => {
+const handleSubmitScan = useCallback(async () => {
+_scanModels({ scan_path: scanPath }).catch((error) => {
 if (error) {
 setErrorMessage(error.data.detail);
 }
 });
-}, [_scanFolder, scanPath]);
+}, [_scanModels, scanPath]);

 const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
 setScanPath(e.target.value);
@@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
 <Input value={scanPath} onChange={handleSetScanPath} />
 </Flex>

-<Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}>
+<Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}>
 {t('modelManager.scanFolder')}
 </Button>
 </Flex>
@@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
 import { PiXBold } from 'react-icons/pi';
 import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';

-import { ScanModelResultItem } from './ScanFolderResultItem';
+import { ScanModelResultItem } from './ScanModelResultItem';

 type ScanModelResultsProps = {
 results: ScanFolderResponse;
@@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
 location: string;
 };

-export const InstallModelForm = () => {
+export const SimpleImport = () => {
 const dispatch = useAppDispatch();

 const [installModel, { isLoading }] = useInstallModelMutation();
@@ -1,11 +1,12 @@
 import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
 import { useTranslation } from 'react-i18next';

-import { InstallModelForm } from './AddModelPanel/InstallModelForm';
-import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
-import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
+import { AdvancedImport } from './AddModelPanel/AdvancedImport';
+import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
+import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
+import { SimpleImport } from './AddModelPanel/SimpleImport';

-export const InstallModels = () => {
+export const ImportModels = () => {
 const { t } = useTranslation();
 return (
 <Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
@@ -16,11 +17,15 @@ export const InstallModels = () => {
 <Tabs variant="collapse" height="100%">
 <TabList>
 <Tab>{t('common.simple')}</Tab>
+<Tab>{t('modelManager.advanced')}</Tab>
 <Tab>{t('modelManager.scan')}</Tab>
 </TabList>
 <TabPanels p={3} height="100%">
 <TabPanel>
-<InstallModelForm />
+<SimpleImport />
+</TabPanel>
+<TabPanel height="100%">
+<AdvancedImport />
 </TabPanel>
 <TabPanel height="100%">
 <ScanModelsForm />
@@ -29,7 +34,7 @@ export const InstallModels = () => {
 </Tabs>
 </Box>
 <Box layerStyle="second" borderRadius="base" w="full" h="50%">
-<ModelInstallQueue />
+<ImportQueue />
 </Box>
 </Flex>
 );
@@ -5,7 +5,7 @@ import { useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
 import { IoFilter } from 'react-icons/io5';

-const MODEL_TYPE_LABELS: { [key: string]: string } = {
+export const MODEL_TYPE_LABELS: { [key: string]: string } = {
 main: 'Main',
 lora: 'LoRA',
 embedding: 'Textual Inversion',
@@ -1,14 +1,14 @@
 import { Box } from '@invoke-ai/ui-library';
 import { useAppSelector } from 'app/store/storeHooks';

-import { InstallModels } from './InstallModels';
+import { ImportModels } from './ImportModels';
 import { Model } from './ModelPanel/Model';

 export const ModelPane = () => {
 const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
 return (
 <Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
-{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
+{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
 </Box>
 );
 };
@@ -1,12 +1,11 @@
-import { Text } from '@invoke-ai/ui-library';
 import { skipToken } from '@reduxjs/toolkit/query';
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
+import Loading from 'common/components/Loading/Loading';
 import { selectConfigSlice } from 'features/system/store/configSlice';
 import { isNil } from 'lodash-es';
 import { useMemo } from 'react';
-import { useTranslation } from 'react-i18next';
-import { useGetModelConfigQuery } from 'services/api/endpoints/models';
+import { useGetModelMetadataQuery } from 'services/api/endpoints/models';

 import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';

@@ -24,9 +23,8 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)

 export const DefaultSettings = () => {
 const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
-const { t } = useTranslation();

-const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
+const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
 const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
 useAppSelector(initialStatesSelector);

@@ -61,7 +59,7 @@ export const DefaultSettings = () => {
 ]);

 if (isLoading) {
-return <Text>{t('common.loading')}</Text>;
+return <Loading />;
 }

 return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
@@ -8,7 +8,7 @@ import type { SubmitHandler } from 'react-hook-form';
 import { useForm } from 'react-hook-form';
 import { useTranslation } from 'react-i18next';
 import { IoPencil } from 'react-icons/io5';
-import { useUpdateModelMutation } from 'services/api/endpoints/models';
+import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';

 import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
 import { DefaultCfgScale } from './DefaultCfgScale';
@@ -41,7 +41,7 @@ export const DefaultSettingsForm = ({
 const { t } = useTranslation();
 const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);

-const [updateModel, { isLoading }] = useUpdateModelMutation();
+const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();

 const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
 defaultValues: defaultSettingsDefaults,
@@ -62,7 +62,7 @@ export const DefaultSettingsForm = ({
 scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
 };

-updateModel({
+editModelMetadata({
 key: selectedModelKey,
 body: { default_settings: body },
 })
@@ -90,7 +90,7 @@ export const DefaultSettingsForm = ({
 }
 });
 },
-[selectedModelKey, dispatch, updateModel, t]
+[selectedModelKey, dispatch, editModelMetadata, t]
 );

 return (
@@ -3,9 +3,9 @@ import { Combobox } from '@invoke-ai/ui-library';
 import { typedMemo } from 'common/util/typedMemo';
 import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
 import { useCallback, useMemo } from 'react';
-import type { Control } from 'react-hook-form';
+import type { UseControllerProps } from 'react-hook-form';
 import { useController } from 'react-hook-form';
-import type { UpdateModelArg } from 'services/api/endpoints/models';
+import type { AnyModelConfig } from 'services/api/types';

 const options: ComboboxOption[] = [
 { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
@@ -14,12 +14,8 @@ const options: ComboboxOption[] = [
 { value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
 ];

-type Props = {
-control: Control<UpdateModelArg['body']>;
-};
-
-const BaseModelSelect = ({ control }: Props) => {
-const { field } = useController({ control, name: 'base' });
+const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
+const { field } = useController(props);
 const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
 const onChange = useCallback<ComboboxOnChange>(
 (v) => {
@@ -0,0 +1,27 @@
+import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
+import { Combobox } from '@invoke-ai/ui-library';
+import { typedMemo } from 'common/util/typedMemo';
+import { useCallback, useMemo } from 'react';
+import type { UseControllerProps } from 'react-hook-form';
+import { useController } from 'react-hook-form';
+import type { AnyModelConfig } from 'services/api/types';
+
+const options: ComboboxOption[] = [
+  { value: 'none', label: '-' },
+  { value: 'true', label: 'True' },
+  { value: 'false', label: 'False' },
+];
+
+const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
+  const { field } = useController(props);
+  const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
+  const onChange = useCallback<ComboboxOnChange>(
+    (v) => {
+      v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value === 'true');
+    },
+    [field]
+  );
+  return <Combobox value={value} options={options} onChange={onChange} />;
+};
+
+export default typedMemo(BooleanSelect);
@@ -0,0 +1,47 @@
+import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
+import { Combobox } from '@invoke-ai/ui-library';
+import { typedMemo } from 'common/util/typedMemo';
+import { useCallback, useMemo } from 'react';
+import type { UseControllerProps } from 'react-hook-form';
+import { useController, useWatch } from 'react-hook-form';
+import type { AnyModelConfig } from 'services/api/types';
+
+const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
+  const { field, formState } = useController(props);
+  const type = useWatch({ control: props.control, name: 'type' });
+
+  const onChange = useCallback<ComboboxOnChange>(
+    (v) => {
+      field.onChange(v?.value);
+    },
+    [field]
+  );
+
+  const options: ComboboxOption[] = useMemo(() => {
+    const modelType = type || formState.defaultValues?.type;
+    if (modelType === 'lora') {
+      return [
+        { value: 'lycoris', label: 'LyCORIS' },
+        { value: 'diffusers', label: 'Diffusers' },
+      ];
+    } else if (modelType === 'embedding') {
+      return [
+        { value: 'embedding_file', label: 'Embedding File' },
+        { value: 'embedding_folder', label: 'Embedding Folder' },
+      ];
+    } else if (modelType === 'ip_adapter') {
+      return [{ value: 'invokeai', label: 'invokeai' }];
+    } else {
+      return [
+        { value: 'diffusers', label: 'Diffusers' },
+        { value: 'checkpoint', label: 'Checkpoint' },
+      ];
+    }
+  }, [type, formState.defaultValues?.type]);
+
+  const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
+
+  return <Combobox value={value} options={options} onChange={onChange} />;
+};
+
+export default typedMemo(ModelFormatSelect);
@@ -0,0 +1,32 @@
+import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
+import { Combobox } from '@invoke-ai/ui-library';
+import { typedMemo } from 'common/util/typedMemo';
+import { MODEL_TYPE_LABELS } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter';
+import { useCallback, useMemo } from 'react';
+import type { UseControllerProps } from 'react-hook-form';
+import { useController } from 'react-hook-form';
+import type { AnyModelConfig } from 'services/api/types';
+
+const options: ComboboxOption[] = [
+  { value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
+  { value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
+  { value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
+  { value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
+  { value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
+  { value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
+  { value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
+] as const;
+
+const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
+  const { field } = useController(props);
+  const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
+  const onChange = useCallback<ComboboxOnChange>(
+    (v) => {
+      field.onChange(v?.value);
+    },
+    [field]
+  );
+  return <Combobox value={value} options={options} onChange={onChange} />;
+};
+
+export default typedMemo(ModelTypeSelect);
@@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
 import { Combobox } from '@invoke-ai/ui-library';
 import { typedMemo } from 'common/util/typedMemo';
 import { useCallback, useMemo } from 'react';
-import type { Control } from 'react-hook-form';
+import type { UseControllerProps } from 'react-hook-form';
 import { useController } from 'react-hook-form';
-import type { UpdateModelArg } from 'services/api/endpoints/models';
+import type { AnyModelConfig } from 'services/api/types';

 const options: ComboboxOption[] = [
 { value: 'normal', label: 'Normal' },
@@ -12,12 +12,8 @@ const options: ComboboxOption[] = [
 { value: 'depth', label: 'Depth' },
 ];

-type Props = {
-control: Control<UpdateModelArg['body']>;
-};
-
-const ModelVariantSelect = ({ control }: Props) => {
-const { field } = useController({ control, name: 'variant' });
+const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
+const { field } = useController(props);
 const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
 const onChange = useCallback<ComboboxOnChange>(
 (v) => {
@@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
 import { Combobox } from '@invoke-ai/ui-library';
 import { typedMemo } from 'common/util/typedMemo';
 import { useCallback, useMemo } from 'react';
-import type { Control } from 'react-hook-form';
+import type { UseControllerProps } from 'react-hook-form';
 import { useController } from 'react-hook-form';
-import type { UpdateModelArg } from 'services/api/endpoints/models';
+import type { AnyModelConfig } from 'services/api/types';

 const options: ComboboxOption[] = [
 { value: 'none', label: '-' },
@@ -13,12 +13,8 @@ const options: ComboboxOption[] = [
 { value: 'sample', label: 'sample' },
 ];

-type Props = {
-control: Control<UpdateModelArg['body']>;
-};
-
-const PredictionTypeSelect = ({ control }: Props) => {
-const { field } = useController({ control, name: 'prediction_type' });
+const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
+const { field } = useController(props);
 const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
 const onChange = useCallback<ComboboxOnChange>(
 (v) => {
@@ -0,0 +1,27 @@
+import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
+import { Combobox } from '@invoke-ai/ui-library';
+import { typedMemo } from 'common/util/typedMemo';
+import { useCallback, useMemo } from 'react';
+import type { UseControllerProps } from 'react-hook-form';
+import { useController } from 'react-hook-form';
+import type { AnyModelConfig } from 'services/api/types';
+
+const options: ComboboxOption[] = [
+  { value: 'none', label: '-' },
+  { value: 'fp16', label: 'fp16' },
+  { value: 'fp32', label: 'fp32' },
+];
+
+const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
+  const { field } = useController(props);
+  const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
+  const onChange = useCallback<ComboboxOnChange>(
+    (v) => {
+      v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
+    },
+    [field]
+  );
+  return <Combobox value={value} options={options} onChange={onChange} />;
+};
+
+export default typedMemo(RepoVariantSelect);
@@ -1,41 +1,18 @@
-import { Box, Flex } from '@invoke-ai/ui-library';
+import { Flex } from '@invoke-ai/ui-library';
 import { skipToken } from '@reduxjs/toolkit/query';
 import { useAppSelector } from 'app/store/storeHooks';
 import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
-import { useMemo } from 'react';
-import { useGetModelConfigQuery } from 'services/api/endpoints/models';
-import type { ModelType } from 'services/api/types';
-
-import { TriggerPhrases } from './TriggerPhrases';
-
-const MODEL_TYPE_TRIGGER_PHRASE: ModelType[] = ['main', 'lora'];
+import { useGetModelMetadataQuery } from 'services/api/endpoints/models';

 export const ModelMetadata = () => {
 const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
-const { data } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
+const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);

-const shouldShowTriggerPhraseSettings = useMemo(() => {
-if (!data?.type) {
-return false;
-}
-return MODEL_TYPE_TRIGGER_PHRASE.includes(data.type);
-}, [data]);
-
-const apiResponseFormatted = useMemo(() => {
-if (!data?.source_api_response) {
-return {};
-}
-return JSON.parse(data.source_api_response);
-}, [data?.source_api_response]);
-
 return (
-<Flex flexDir="column" height="full" gap="3">
-{shouldShowTriggerPhraseSettings && (
-<Box layerStyle="second" borderRadius="base" p={3}>
-<TriggerPhrases />
-</Box>
-)}
-<DataViewer label="metadata" data={apiResponseFormatted} />
-</Flex>
+<>
+<Flex flexDir="column" height="full" gap="3">
+<DataViewer label="metadata" data={metadata || {}} />
+</Flex>
+</>
 );
 };
@@ -1,106 +0,0 @@
-import {
-  Button,
-  Flex,
-  FormControl,
-  FormErrorMessage,
-  Input,
-  Tag,
-  TagCloseButton,
-  TagLabel,
-} from '@invoke-ai/ui-library';
-import { skipToken } from '@reduxjs/toolkit/query';
-import { useAppSelector } from 'app/store/storeHooks';
-import { ModelListHeader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader';
-import type { ChangeEvent } from 'react';
-import { useCallback, useMemo, useState } from 'react';
-import { useTranslation } from 'react-i18next';
-import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
-
-export const TriggerPhrases = () => {
-  const { t } = useTranslation();
-  const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
-  const { data: modelConfig } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
-  const [phrase, setPhrase] = useState('');
-
-  const [updateModel, { isLoading }] = useUpdateModelMutation();
-
-  const handlePhraseChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
-    setPhrase(e.target.value);
-  }, []);
-
-  const triggerPhrases = useMemo(() => {
-    return modelConfig?.trigger_phrases || [];
-  }, [modelConfig?.trigger_phrases]);
-
-  const errors = useMemo(() => {
-    const errors = [];
-
-    if (phrase.length && triggerPhrases.includes(phrase)) {
-      errors.push('Phrase is already in list');
-    }
-
-    return errors;
-  }, [phrase, triggerPhrases]);
-
-  const addTriggerPhrase = useCallback(async () => {
-    if (!selectedModelKey) {
-      return;
-    }
-
-    if (!phrase.length || triggerPhrases.includes(phrase)) {
-      return;
-    }
-
-    await updateModel({
-      key: selectedModelKey,
-      body: { trigger_phrases: [...triggerPhrases, phrase] },
-    }).unwrap();
-    setPhrase('');
-  }, [updateModel, selectedModelKey, phrase, triggerPhrases]);
-
-  const removeTriggerPhrase = useCallback(
-    async (phraseToRemove: string) => {
-      if (!selectedModelKey) {
-        return;
-      }
-
-      const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
-
-      await updateModel({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
-    },
-    [updateModel, selectedModelKey, triggerPhrases]
-  );
-
-  return (
-    <Flex flexDir="column" w="full" gap="5">
-      <ModelListHeader title={t('modelManager.triggerPhrases')} />
-      <form>
-        <FormControl w="full" isInvalid={Boolean(errors.length)}>
-          <Flex flexDir="column" w="full">
-            <Flex gap="3" alignItems="center" w="full">
-              <Input value={phrase} onChange={handlePhraseChange} placeholder={t('modelManager.typePhraseHere')} />
-              <Button
-                type="submit"
-                onClick={addTriggerPhrase}
-                isDisabled={Boolean(errors.length)}
-                isLoading={isLoading}
-              >
-                {t('common.add')}
-              </Button>
-            </Flex>
-            {!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
-          </Flex>
-        </FormControl>
-      </form>
-
-      <Flex gap="4" flexWrap="wrap" mt="3" mb="3">
-        {triggerPhrases.map((phrase, index) => (
-          <Tag size="md" key={index}>
-            <TagLabel>{phrase}</TagLabel>
-            <TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
-          </Tag>
-        ))}
-      </Flex>
-    </Flex>
-  );
-};
@@ -13,7 +13,7 @@ import { addToast } from 'features/system/store/systemSlice';
 import { makeToast } from 'features/system/util/makeToast';
 import { useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
-import { useConvertModelMutation } from 'services/api/endpoints/models';
+import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
 import type { CheckpointModelConfig } from 'services/api/types';

 interface ModelConvertProps {
@@ -24,7 +24,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
 const { model } = props;
 const dispatch = useAppDispatch();
 const { t } = useTranslation();
-const [convertModel, { isLoading }] = useConvertModelMutation();
+const [convertModel, { isLoading }] = useConvertMainModelsMutation();
 const { isOpen, onOpen, onClose } = useDisclosure();

 const modelConvertHandler = useCallback(() => {
@@ -1,6 +1,5 @@
 import {
 Button,
-Checkbox,
 Flex,
 FormControl,
 FormErrorMessage,
@@ -20,27 +19,66 @@ import type { SubmitHandler } from 'react-hook-form';
 import { useForm } from 'react-hook-form';
 import { useTranslation } from 'react-i18next';
 import type { UpdateModelArg } from 'services/api/endpoints/models';
-import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
+import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
+import type { AnyModelConfig } from 'services/api/types';

 import BaseModelSelect from './Fields/BaseModelSelect';
+import BooleanSelect from './Fields/BooleanSelect';
+import ModelFormatSelect from './Fields/ModelFormatSelect';
+import ModelTypeSelect from './Fields/ModelTypeSelect';
 import ModelVariantSelect from './Fields/ModelVariantSelect';
 import PredictionTypeSelect from './Fields/PredictionTypeSelect';
+import RepoVariantSelect from './Fields/RepoVariantSelect';

 export const ModelEdit = () => {
 const dispatch = useAppDispatch();
 const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
 const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);

-const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
+const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();

 const { t } = useTranslation();

+// const modelData = useMemo(() => {
+//   if (!data) {
+//     return null;
+//   }
+//   const modelFormat = data.format;
+//   const modelType = data.type;
+
+//   if (modelType === 'main') {
+//     if (modelFormat === 'diffusers') {
+//       return data as DiffusersModelConfig;
+//     } else if (modelFormat === 'checkpoint') {
+//       return data as CheckpointModelConfig;
+//     }
+//   }
+
+//   switch (modelType) {
+//     case 'lora':
+//       return data as LoRAModelConfig;
+//     case 'embedding':
+//       return data as TextualInversionModelConfig;
+//     case 't2i_adapter':
+//       return data as T2IAdapterModelConfig;
+//     case 'ip_adapter':
+//       return data as IPAdapterModelConfig;
+//     case 'controlnet':
+//       return data as ControlNetModelConfig;
+//     case 'vae':
+//       return data as VAEModelConfig;
+//     default:
+//       return null;
+//   }
+// }, [data]);

 const {
 register,
 handleSubmit,
 control,
 formState: { errors },
 reset,
+watch,
 } = useForm<UpdateModelArg['body']>({
 defaultValues: {
 ...data,
@@ -48,7 +86,10 @@ export const ModelEdit = () => {
 mode: 'onChange',
 });

-const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
+const watchedModelType = watch('type');
+const watchedModelFormat = watch('format');
+
+const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
 (values) => {
 if (!data?.key) {
 return;
@@ -102,31 +143,33 @@ export const ModelEdit = () => {
 return (
 <Flex flexDir="column" h="full">
 <form onSubmit={handleSubmit(onSubmit)}>
-<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
-<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
+<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
+<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
 <FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel>
 <Input
 {...register('name', {
-validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
+validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
 })}
 size="lg"
 />

-{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
-</FormControl>
-<Button size="sm" onClick={handleClickCancel}>
-{t('common.cancel')}
-</Button>
-<Button
-size="sm"
-colorScheme="invokeYellow"
-onClick={handleSubmit(onSubmit)}
-isLoading={isSubmitting}
-isDisabled={Boolean(Object.keys(errors).length)}
->
-{t('common.save')}
-</Button>
+<Flex gap={2}>
+<Button size="sm" onClick={handleClickCancel}>
+{t('common.cancel')}
+</Button>
+<Button
+size="sm"
+colorScheme="invokeYellow"
+onClick={handleSubmit(onSubmit)}
+isLoading={isSubmitting}
+isDisabled={Boolean(Object.keys(errors).length)}
+>
+{t('common.save')}
+</Button>
+</Flex>
 </Flex>
+{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
+</FormControl>

 <Flex flexDir="column" gap={3} mt="4">
 <Flex>
@@ -141,22 +184,76 @@ export const ModelEdit = () => {
 <Flex gap={4}>
 <FormControl flexDir="column" alignItems="flex-start" gap={1}>
 <FormLabel>{t('modelManager.baseModel')}</FormLabel>
-<BaseModelSelect control={control} />
+<BaseModelSelect control={control} name="base" />
+</FormControl>
+<FormControl flexDir="column" alignItems="flex-start" gap={1}>
+<FormLabel>{t('modelManager.modelType')}</FormLabel>
+<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
 </FormControl>
 </Flex>
-{data.type === 'main' && data.format === 'checkpoint' && (
+<Flex gap={4}>
+<FormControl flexDir="column" alignItems="flex-start" gap={1}>
+<FormLabel>{t('common.format')}</FormLabel>
+<ModelFormatSelect control={control} name="format" />
+</FormControl>
+<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
+<FormLabel>{t('modelManager.path')}</FormLabel>
+<Input
+{...register('path', {
+validate: (value) => value.trim().length > 0 || 'Must provide a path',
+})}
+/>
+{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
+</FormControl>
+</Flex>
+{watchedModelType === 'main' && (
+<>
+<Flex gap={4}>
+{watchedModelFormat === 'diffusers' && (
+<FormControl flexDir="column" alignItems="flex-start" gap={1}>
+<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
+<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
+</FormControl>
+)}
+{watchedModelFormat === 'checkpoint' && (
+<FormControl flexDir="column" alignItems="flex-start" gap={1}>
+<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
+<Input {...register('config')} />
+</FormControl>
+)}
+
+<FormControl flexDir="column" alignItems="flex-start" gap={1}>
+<FormLabel>{t('modelManager.variant')}</FormLabel>
+<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
+</FormControl>
+</Flex>
+<Flex gap={4}>
+<FormControl flexDir="column" alignItems="flex-start" gap={1}>
+<FormLabel>{t('modelManager.predictionType')}</FormLabel>
+<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
+</FormControl>
+<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||||
|
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
|
||||||
|
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
||||||
|
<Input {...register('vae')} />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{watchedModelType === 'ip_adapter' && (
|
||||||
<Flex gap={4}>
|
<Flex gap={4}>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
|
||||||
<ModelVariantSelect control={control} />
|
<Input {...register('image_encoder_model_id')} />
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
|
||||||
<PredictionTypeSelect control={control} />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
|
||||||
<Checkbox {...register('upcast_attention')} />
|
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
|
|||||||
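As an aside on the `watch('type')` / `watch('format')` lines added above: react-hook-form's `watch` subscribes the component to a form field and re-renders it when that field changes, which is what lets the edit form swap its conditional sections as the user changes the type and format selects. A minimal, standalone sketch of the pattern (this is not the ModelEdit component; the field names and options are illustrative only):

```tsx
import { useForm } from 'react-hook-form';

type FormValues = { type: string; format: string };

export const ExampleForm = () => {
  const { register, watch } = useForm<FormValues>({
    defaultValues: { type: 'main', format: 'checkpoint' },
  });
  // watch() re-renders this component whenever 'type' changes,
  // so the JSX below can branch on the current value.
  const watchedType = watch('type');

  return (
    <form>
      <select {...register('type')}>
        <option value="main">main</option>
        <option value="lora">lora</option>
      </select>
      {/* Conditional section, analogous to the watchedModelType === 'main' branch in the diff */}
      {watchedType === 'main' && <input {...register('format')} />}
    </form>
  );
};
```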
@@ -91,19 +91,26 @@ export const ModelView = () => {
             <ModelAttrView label={t('modelManager.path')} value={modelData.path} />
           </Flex>
           {modelData.type === 'main' && (
-            <Flex gap={2}>
-              {modelData.format === 'diffusers' && modelData.repo_variant && (
-                <ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
-              )}
-              {modelData.format === 'checkpoint' && (
-                <>
-                  <ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
-                  <ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
-                  <ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
-                  <ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
-                </>
-              )}
-            </Flex>
+            <>
+              <Flex gap={2}>
+                {modelData.format === 'diffusers' && (
+                  <ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
+                )}
+                {modelData.format === 'checkpoint' && (
+                  <ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
+                )}
+
+                <ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
+              </Flex>
+              <Flex gap={2}>
+                <ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
+                <ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
+              </Flex>
+              <Flex gap={2}>
+                <ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
+                <ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
+              </Flex>
+            </>
           )}
           {modelData.type === 'ip_adapter' && (
             <Flex gap={2}>
@@ -1,10 +1,10 @@
 import { Box, Textarea } from '@invoke-ai/ui-library';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
+import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
+import { usePrompt } from 'features/embedding/usePrompt';
 import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
 import { setNegativePrompt } from 'features/parameters/store/generationSlice';
-import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
-import { PromptPopover } from 'features/prompt/PromptPopover';
-import { usePrompt } from 'features/prompt/usePrompt';
 import { memo, useCallback, useRef } from 'react';
 import { useTranslation } from 'react-i18next';

@@ -19,14 +19,19 @@ export const ParamNegativePrompt = memo(() => {
     },
     [dispatch]
   );
-  const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
+  const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
     prompt,
     textareaRef,
     onChange: _onChange,
   });

   return (
-    <PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
+    <EmbeddingPopover
+      isOpen={isOpen}
+      onClose={onClose}
+      onSelect={onSelectEmbedding}
+      width={textareaRef.current?.clientWidth}
+    >
       <Box pos="relative">
         <Textarea
           id="negativePrompt"
@@ -40,10 +45,10 @@ export const ParamNegativePrompt = memo(() => {
           variant="darkFilled"
         />
         <PromptOverlayButtonWrapper>
-          <AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
+          <AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
         </PromptOverlayButtonWrapper>
       </Box>
-    </PromptPopover>
+    </EmbeddingPopover>
   );
 });
@@ -1,11 +1,11 @@
 import { Box, Textarea } from '@invoke-ai/ui-library';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
+import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
+import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
+import { usePrompt } from 'features/embedding/usePrompt';
 import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
 import { setPositivePrompt } from 'features/parameters/store/generationSlice';
-import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
-import { PromptPopover } from 'features/prompt/PromptPopover';
-import { usePrompt } from 'features/prompt/usePrompt';
 import { SDXLConcatButton } from 'features/sdxl/components/SDXLPrompts/SDXLConcatButton';
 import { memo, useCallback, useRef } from 'react';
 import type { HotkeyCallback } from 'react-hotkeys-hook';

@@ -25,7 +25,7 @@ export const ParamPositivePrompt = memo(() => {
     },
     [dispatch]
   );
-  const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
+  const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
     prompt,
     textareaRef: textareaRef,
     onChange: handleChange,

@@ -42,7 +42,12 @@ export const ParamPositivePrompt = memo(() => {
   useHotkeys('alt+a', focus, []);

   return (
-    <PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
+    <EmbeddingPopover
+      isOpen={isOpen}
+      onClose={onClose}
+      onSelect={onSelectEmbedding}
+      width={textareaRef.current?.clientWidth}
+    >
       <Box pos="relative">
         <Textarea
           id="prompt"
@@ -56,12 +61,12 @@ export const ParamPositivePrompt = memo(() => {
           variant="darkFilled"
         />
         <PromptOverlayButtonWrapper>
-          <AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
+          <AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
           {baseModel === 'sdxl' && <SDXLConcatButton />}
           <ShowDynamicPromptsPreviewButton />
         </PromptOverlayButtonWrapper>
       </Box>
-    </PromptPopover>
+    </EmbeddingPopover>
   );
 });
@@ -16,6 +16,7 @@ import type {
   ParameterScheduler,
   ParameterVAEModel,
 } from 'features/parameters/types/parameterSchemas';
+import { zParameterModel } from 'features/parameters/types/parameterSchemas';
 import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
 import { configChanged } from 'features/system/store/configSlice';
 import { clamp } from 'lodash-es';

@@ -209,6 +210,26 @@ export const generationSlice = createSlice({
   },
   extraReducers: (builder) => {
     builder.addCase(configChanged, (state, action) => {
+      const defaultModel = action.payload.sd?.defaultModel;
+
+      if (defaultModel && !state.model) {
+        const [base_model, model_type, model_name] = defaultModel.split('/');
+
+        const result = zParameterModel.safeParse({
+          model_name,
+          base_model,
+          model_type,
+        });
+
+        if (result.success) {
+          state.model = result.data;
+
+          const optimalDimension = getOptimalDimension(result.data);
+
+          state.width = optimalDimension;
+          state.height = optimalDimension;
+        }
+      }
       if (action.payload.sd?.scheduler) {
         state.scheduler = action.payload.sd.scheduler;
       }
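As the destructured `split('/')` above implies, the new `configChanged` case expects `sd.defaultModel` to be a single slash-delimited identifier of the form base/type/name, which `zParameterModel.safeParse` then validates. A rough sketch of that parsing step in isolation (the example value is hypothetical, not taken from the diff):

```ts
// Hypothetical example value; the real string comes from the app config (sd.defaultModel).
const defaultModel = 'sd-1/main/stable-diffusion-v1-5';

// Same destructuring as in the reducer: base model, then model type, then model name.
const [base_model, model_type, model_name] = defaultModel.split('/');

console.log({ base_model, model_type, model_name });
// { base_model: 'sd-1', model_type: 'main', model_name: 'stable-diffusion-v1-5' }
```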
@@ -1,21 +0,0 @@
-import type { Meta, StoryObj } from '@storybook/react';
-
-import { PromptTriggerSelect } from './PromptTriggerSelect';
-import type { PromptTriggerSelectProps } from './types';
-
-const meta: Meta<typeof PromptTriggerSelect> = {
-  title: 'Feature/Prompt/PromptTriggerSelect',
-  tags: ['autodocs'],
-  component: PromptTriggerSelect,
-};
-
-export default meta;
-type Story = StoryObj<typeof PromptTriggerSelect>;
-
-const Component = (props: PromptTriggerSelectProps) => {
-  return <PromptTriggerSelect {...props}>Invoke</PromptTriggerSelect>;
-};
-
-export const Default: Story = {
-  render: Component,
-};
@@ -1,104 +0,0 @@
-import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
-import { Combobox, FormControl } from '@invoke-ai/ui-library';
-import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
-import { useAppSelector } from 'app/store/storeHooks';
-import type { GroupBase } from 'chakra-react-select';
-import { selectLoraSlice } from 'features/lora/store/loraSlice';
-import type { PromptTriggerSelectProps } from 'features/prompt/types';
-import { t } from 'i18next';
-import { flatten, map } from 'lodash-es';
-import { memo, useCallback, useMemo } from 'react';
-import { useTranslation } from 'react-i18next';
-import {
-  loraModelsAdapterSelectors,
-  textualInversionModelsAdapterSelectors,
-  useGetLoRAModelsQuery,
-  useGetTextualInversionModelsQuery,
-} from 'services/api/endpoints/models';
-
-const noOptionsMessage = () => t('prompt.noMatchingTriggers');
-
-const selectLoRAs = createMemoizedSelector(selectLoraSlice, (loras) => loras.loras);
-
-export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
-  const { t } = useTranslation();
-
-  const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
-  const addedLoRAs = useAppSelector(selectLoRAs);
-  const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
-  const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
-
-  const _onChange = useCallback<ComboboxOnChange>(
-    (v) => {
-      if (!v) {
-        onSelect('');
-        return;
-      }
-
-      onSelect(v.value);
-    },
-    [onSelect]
-  );
-
-  const options = useMemo(() => {
-    const _options: GroupBase<ComboboxOption>[] = [];
-
-    if (tiModels) {
-      const embeddingOptions = textualInversionModelsAdapterSelectors
-        .selectAll(tiModels)
-        .filter((ti) => ti.base === currentBaseModel)
-        .map((model) => ({ label: model.name, value: `<${model.name}>` }));
-
-      if (embeddingOptions.length > 0) {
-        _options.push({
-          label: t('prompt.compatibleEmbeddings'),
-          options: embeddingOptions,
-        });
-      }
-    }
-
-    if (loraModels) {
-      const triggerPhraseOptions = loraModelsAdapterSelectors
-        .selectAll(loraModels)
-        .filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
-        .map((lora) => {
-          if (lora.trigger_phrases) {
-            return lora.trigger_phrases.map((triggerPhrase) => ({ label: triggerPhrase, value: triggerPhrase }));
-          }
-          return [];
-        })
-        .flatMap((x) => x);
-
-      if (triggerPhraseOptions.length > 0) {
-        _options.push({
-          label: t('modelManager.triggerPhrases'),
-          options: flatten(triggerPhraseOptions),
-        });
-      }
-    }
-    return _options;
-  }, [tiModels, loraModels, t, currentBaseModel, addedLoRAs]);
-
-  return (
-    <FormControl>
-      <Combobox
-        placeholder={isLoadingLoRAs || isLoadingTIs ? t('common.loading') : t('prompt.addPromptTrigger')}
-        defaultMenuIsOpen
-        autoFocus
-        value={null}
-        options={options}
-        noOptionsMessage={noOptionsMessage}
-        onChange={_onChange}
-        onMenuClose={onClose}
-        data-testid="add-prompt-trigger"
-        sx={selectStyles}
-      />
-    </FormControl>
-  );
-});
-
-PromptTriggerSelect.displayName = 'PromptTriggerSelect';
-
-const selectStyles: ChakraProps['sx'] = {
-  w: 'full',
-};
@@ -1,9 +1,9 @@
 import { Box, Textarea } from '@invoke-ai/ui-library';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
+import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
+import { usePrompt } from 'features/embedding/usePrompt';
 import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
-import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
-import { PromptPopover } from 'features/prompt/PromptPopover';
-import { usePrompt } from 'features/prompt/usePrompt';
 import { setNegativeStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
 import { memo, useCallback, useRef } from 'react';
 import { useHotkeys } from 'react-hotkeys-hook';

@@ -20,7 +20,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
     },
     [dispatch]
   );
-  const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
+  const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
     prompt,
     textareaRef: textareaRef,
     onChange: handleChange,

@@ -29,7 +29,12 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
   useHotkeys('alt+a', onFocus, []);

   return (
-    <PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
+    <EmbeddingPopover
+      isOpen={isOpen}
+      onClose={onClose}
+      onSelect={onSelectEmbedding}
+      width={textareaRef.current?.clientWidth}
+    >
       <Box pos="relative">
         <Textarea
           id="prompt"
@@ -43,10 +48,10 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
           variant="darkFilled"
         />
         <PromptOverlayButtonWrapper>
-          <AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
+          <AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
         </PromptOverlayButtonWrapper>
       </Box>
-    </PromptPopover>
+    </EmbeddingPopover>
   );
 });
@@ -1,9 +1,9 @@
 import { Box, Textarea } from '@invoke-ai/ui-library';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
+import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
+import { usePrompt } from 'features/embedding/usePrompt';
 import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
-import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
-import { PromptPopover } from 'features/prompt/PromptPopover';
-import { usePrompt } from 'features/prompt/usePrompt';
 import { setPositiveStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
 import { memo, useCallback, useRef } from 'react';
 import { useTranslation } from 'react-i18next';

@@ -19,14 +19,19 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
     },
     [dispatch]
   );
-  const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
+  const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
     prompt,
     textareaRef: textareaRef,
     onChange: handleChange,
   });

   return (
-    <PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
+    <EmbeddingPopover
+      isOpen={isOpen}
+      onClose={onClose}
+      onSelect={onSelectEmbedding}
+      width={textareaRef.current?.clientWidth}
+    >
       <Box pos="relative">
         <Textarea
           id="prompt"
@@ -40,10 +45,10 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
           variant="darkFilled"
         />
         <PromptOverlayButtonWrapper>
-          <AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
+          <AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
         </PromptOverlayButtonWrapper>
       </Box>
-    </PromptPopover>
+    </EmbeddingPopover>
   );
 });
@@ -1,6 +1,7 @@
 import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
 import { createEntityAdapter } from '@reduxjs/toolkit';
 import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
+import type { JSONObject } from 'common/types';
 import queryString from 'query-string';
 import type { operations, paths } from 'services/api/schema';
 import type {

@@ -23,33 +24,49 @@ export type UpdateModelArg = {
   body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
 };

+type UpdateModelMetadataArg = {
+  key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
+  body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
+};
+
 type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
+type UpdateModelMetadataResponse =
+  paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
+
 type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];

+type GetModelMetadataResponse =
+  paths['/api/v2/models/i/{key}/metadata']['get']['responses']['200']['content']['application/json'];
+
 type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;

-type DeleteModelArg = {
+type DeleteMainModelArg = {
   key: string;
 };
-type DeleteModelResponse = void;
+type DeleteMainModelResponse = void;

 type ConvertMainModelResponse =
   paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];

 type InstallModelArg = {
   source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
+  access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token'];
+  // TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend...
+  config?: JSONObject;
+  // config: NonNullable<paths['/api/v2/models/install']['post']['requestBody']>['content']['application/json'];
 };

 type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];

-type ListModelInstallsResponse =
-  paths['/api/v2/models/install']['get']['responses']['200']['content']['application/json'];
+type ListImportModelsResponse =
+  paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];

-type CancelModelInstallResponse =
-  paths['/api/v2/models/install/{id}']['delete']['responses']['201']['content']['application/json'];
+type DeleteImportModelsResponse =
+  paths['/api/v2/models/import/{id}']['delete']['responses']['201']['content']['application/json'];

-type PruneCompletedModelInstallsResponse =
-  paths['/api/v2/models/install']['delete']['responses']['200']['content']['application/json'];
+type PruneModelImportsResponse =
+  paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];

 export type ScanFolderResponse =
   paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];

@@ -66,7 +83,6 @@ const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
   selectId: (entity) => entity.key,
   sortComparer: (a, b) => a.name.localeCompare(b.name),
 });
-export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
 const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
   selectId: (entity) => entity.key,
   sortComparer: (a, b) => a.name.localeCompare(b.name),

@@ -86,10 +102,6 @@ const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelC
   selectId: (entity) => entity.key,
   sortComparer: (a, b) => a.name.localeCompare(b.name),
 });
-export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
-  undefined,
-  getSelectorsOptions
-);
 const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
   selectId: (entity) => entity.key,
   sortComparer: (a, b) => a.name.localeCompare(b.name),

@@ -134,7 +146,31 @@ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);

 export const modelsApi = api.injectEndpoints({
   endpoints: (build) => ({
-    updateModel: build.mutation<UpdateModelResponse, UpdateModelArg>({
+    getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
+      query: (base_models) => {
+        const params: ListModelsArg = {
+          model_type: 'main',
+          base_models,
+        };
+
+        const query = queryString.stringify(params, { arrayFormat: 'none' });
+        return buildModelsUrl(`?${query}`);
+      },
+      providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
+      transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
+      onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
+        queryFulfilled.then(({ data }) => {
+          upsertModelConfigs(data, dispatch);
+        });
+      },
+    }),
+    getModelMetadata: build.query<GetModelMetadataResponse, string>({
+      query: (key) => {
+        return buildModelsUrl(`i/${key}/metadata`);
+      },
+      providesTags: ['Model'],
+    }),
+    updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
       query: ({ key, body }) => {
         return {
           url: buildModelsUrl(`i/${key}`),

@@ -144,17 +180,28 @@ export const modelsApi = api.injectEndpoints({
       },
       invalidatesTags: ['Model'],
     }),
-    installModel: build.mutation<InstallModelResponse, InstallModelArg>({
-      query: ({ source }) => {
+    updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
+      query: ({ key, body }) => {
         return {
-          url: buildModelsUrl('install'),
-          params: { source },
-          method: 'POST',
+          url: buildModelsUrl(`i/${key}/metadata`),
+          method: 'PATCH',
+          body: body,
         };
       },
-      invalidatesTags: ['Model', 'ModelInstalls'],
+      invalidatesTags: ['Model'],
     }),
-    deleteModels: build.mutation<DeleteModelResponse, DeleteModelArg>({
+    installModel: build.mutation<InstallModelResponse, InstallModelArg>({
+      query: ({ source, config, access_token }) => {
+        return {
+          url: buildModelsUrl('install'),
+          params: { source, access_token },
+          method: 'POST',
+          body: config,
+        };
+      },
+      invalidatesTags: ['Model', 'ModelImports'],
+    }),
+    deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
       query: ({ key }) => {
         return {
           url: buildModelsUrl(`i/${key}`),

@@ -163,7 +210,7 @@ export const modelsApi = api.injectEndpoints({
       },
       invalidatesTags: ['Model'],
     }),
-    convertModel: build.mutation<ConvertMainModelResponse, string>({
+    convertMainModels: build.mutation<ConvertMainModelResponse, string>({
       query: (key) => {
         return {
           url: buildModelsUrl(`convert/${key}`),

@@ -206,57 +253,6 @@ export const modelsApi = api.injectEndpoints({
       },
       invalidatesTags: ['Model'],
     }),
-    scanFolder: build.query<ScanFolderResponse, ScanFolderArg>({
-      query: (arg) => {
-        const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
-        return {
-          url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
-        };
-      },
-    }),
-    listModelInstalls: build.query<ListModelInstallsResponse, void>({
-      query: () => {
-        return {
-          url: buildModelsUrl('install'),
-        };
-      },
-      providesTags: ['ModelInstalls'],
-    }),
-    cancelModelInstall: build.mutation<CancelModelInstallResponse, number>({
-      query: (id) => {
-        return {
-          url: buildModelsUrl(`install/${id}`),
-          method: 'DELETE',
-        };
-      },
-      invalidatesTags: ['ModelInstalls'],
-    }),
-    pruneCompletedModelInstalls: build.mutation<PruneCompletedModelInstallsResponse, void>({
-      query: () => {
-        return {
-          url: buildModelsUrl('install'),
-          method: 'DELETE',
-        };
-      },
-      invalidatesTags: ['ModelInstalls'],
-    }),
-    getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
-      query: (base_models) => {
-        const params: ListModelsArg = {
-          model_type: 'main',
-          base_models,
-        };
-        const query = queryString.stringify(params, { arrayFormat: 'none' });
-        return buildModelsUrl(`?${query}`);
-      },
-      providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
-      transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
-      onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
-        queryFulfilled.then(({ data }) => {
-          upsertModelConfigs(data, dispatch);
-        });
-      },
-    }),
     getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
       query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
       providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),

@@ -317,6 +313,40 @@ export const modelsApi = api.injectEndpoints({
         });
       },
     }),
+    scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
+      query: (arg) => {
+        const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
+        return {
+          url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
+        };
+      },
+    }),
+    getModelImports: build.query<ListImportModelsResponse, void>({
+      query: () => {
+        return {
+          url: buildModelsUrl(`import`),
+        };
+      },
+      providesTags: ['ModelImports'],
+    }),
+    deleteModelImport: build.mutation<DeleteImportModelsResponse, number>({
+      query: (id) => {
+        return {
+          url: buildModelsUrl(`import/${id}`),
+          method: 'DELETE',
+        };
+      },
+      invalidatesTags: ['ModelImports'],
+    }),
+    pruneModelImports: build.mutation<PruneModelImportsResponse, void>({
+      query: () => {
+        return {
+          url: buildModelsUrl('import'),
+          method: 'PATCH',
+        };
+      },
+      invalidatesTags: ['ModelImports'],
+    }),
   }),
 });

@@ -330,14 +360,16 @@ export const {
   useGetTextualInversionModelsQuery,
   useGetVaeModelsQuery,
   useDeleteModelsMutation,
-  useUpdateModelMutation,
+  useUpdateModelsMutation,
   useInstallModelMutation,
-  useConvertModelMutation,
+  useConvertMainModelsMutation,
   useSyncModelsMutation,
-  useLazyScanFolderQuery,
-  useListModelInstallsQuery,
-  useCancelModelInstallMutation,
-  usePruneCompletedModelInstallsMutation,
+  useLazyScanModelsQuery,
+  useGetModelImportsQuery,
+  useGetModelMetadataQuery,
+  useDeleteModelImportMutation,
+  usePruneModelImportsMutation,
+  useUpdateModelMetadataMutation,
 } = modelsApi;

 const upsertModelConfigs = (
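As a rough illustration of how the renamed import-queue hooks exported above might be consumed from a component (the component, its markup, and the assumption that the list response is a plain array of jobs are illustrative only and not taken from the diff):

```tsx
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';

export const ImportQueueSummary = () => {
  // Fetches GET /api/v2/models/import via RTK Query; shape of `data` assumed to be an array of jobs.
  const { data } = useGetModelImportsQuery();
  // Triggers PATCH /api/v2/models/import, which the ModelImports tag invalidation will refetch.
  const [pruneModelImports] = usePruneModelImportsMutation();

  return (
    <div>
      <span>{data ? `${data.length} import jobs` : 'loading…'}</span>
      <button onClick={() => pruneModelImports()}>Prune finished imports</button>
    </div>
  );
};
```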
@@ -28,7 +28,7 @@ export const tagTypes = [
   'InvocationCacheStatus',
   'Model',
   'ModelConfig',
-  'ModelInstalls',
+  'ModelImports',
   'T2IAdapterModel',
   'MainModel',
   'VaeModel',
File diff suppressed because one or more lines are too long
@@ -43,13 +43,14 @@ export type ControlField = S['ControlField'];
 // Model Configs

 // TODO(MM2): Can we make key required in the pydantic model?
-export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
+export type LoRAModelConfig = S['LoRAConfig'];
 // TODO(MM2): Can we rename this from Vae -> VAE
-export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
+export type VAEModelConfig = S['VaeCheckpointConfig'] | S['VaeDiffusersConfig'];
 export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
 export type IPAdapterModelConfig = S['IPAdapterConfig'];
-export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
-export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
+// TODO(MM2): Can we rename this to T2IAdapterConfig
+export type T2IAdapterModelConfig = S['T2IConfig'];
+export type TextualInversionModelConfig = S['TextualInversionConfig'];
 export type DiffusersModelConfig = S['MainDiffusersConfig'];
 export type CheckpointModelConfig = S['MainCheckpointConfig'];
 type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
Some files were not shown because too many files have changed in this diff.