Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-15 17:18:11 -05:00)

Compare commits: 70 commits, `ryan/dense` ... `separate-g`
| SHA1 |
|---|
| cd3f5f30dc |
| 71ee28ac12 |
| 46c904d08a |
| 7d5a88b69d |
| afa4df1991 |
| e30cb4b52f |
| ba1f6bf926 |
| 4a9cca6c2d |
| b0275700b3 |
| 8319aca5f9 |
| 51a604f907 |
| 7515d73628 |
| 2c453aa531 |
| 2cca6e4c76 |
| ef171e890a |
| caafbf2f0d |
| 2db5eaf907 |
| f234bf6256 |
| cfa78b4052 |
| ba1dd4b02b |
| bcf58cac59 |
| e866d90ab2 |
| e8797787cf |
| 0082ecb22b |
| 656839fcd1 |
| 99407c899f |
| 48119d9010 |
| 7c9128b253 |
| 4f9bb00275 |
| 78895b3e80 |
| 3030a34b88 |
| 58fa9c2fac |
| a8b6635050 |
| 6829610a71 |
| 5551cf8ac4 |
| 37b969d339 |
| c953e61294 |
| 93dd3c848e |
| 02bde7bb75 |
| 3391c19926 |
| 0f60b1ced4 |
| 44c40d7d1a |
| 0b9a212363 |
| c3aa985c93 |
| 7cb0da1f66 |
| 3534366146 |
| f2b5f8753f |
| f13f5984c0 |
| 94e1e64296 |
| 2411bf53c0 |
| 9378e47a06 |
| 4471ea8ad1 |
| 2c835fd550 |
| 61b737bb9f |
| a8cd3dfc99 |
| 0cce582f2f |
| 4347d1c7f7 |
| bd4fd9693d |
| 9b40c28144 |
| 16a5d718bf |
| 76cbc745e1 |
| 0a614943f6 |
| e426096d32 |
| c561cd751f |
| af9298f0ef |
| 5b74117836 |
| 38474c9797 |
| b880a31039 |
| dd31bc4586 |
| 316573df2d |
@@ -32,7 +32,6 @@ model. These are the:

Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.

## Location of the Code

The four main services can be found in

@@ -63,23 +62,21 @@ provides the following fields:

|----------------|-----------------|------------------|
| `key` | str | Unique identifier for the model |
| `name` | str | Name of the model (not unique) |
| `model_type` | ModelType | The type of the model |
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
| `base_model` | BaseModelType | The base model that the model is compatible with |
| `path` | str | Location of model on disk |
- | `original_hash` | str | Hash of the model when it was first installed |
- | `current_hash` | str | Most recent hash of the model's contents |
+ | `hash` | str | Hash of the model |
| `description` | str | Human-readable description of the model (optional) |
| `source` | str | Model's source URL or repo id (optional) |

The `key` is a unique 32-character random ID which was generated at
- install time. The `original_hash` field stores a hash of the model's
+ install time. The `hash` field stores a hash of the model's
contents at install time obtained by sampling several parts of the
model's files using the `imohash` library. Over the course of the
model's lifetime it may be transformed in various ways, such as
changing its precision or converting it from a .safetensors to a
- diffusers model. When this happens, `original_hash` is unchanged, but
- `current_hash` is updated to indicate the current contents.
+ diffusers model.
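As a purely illustrative sketch, the kind of sampled content hash described above can be computed with the `imohash` library roughly like this (the exact options InvokeAI passes are not shown here and may differ):

```
from imohash import hashfile

# imohash samples portions of the file rather than reading all of it,
# so hashing even multi-gigabyte checkpoints is fast
model_hash = hashfile('/opt/models/sushi.safetensors', hexdigest=True)
print(model_hash)  # a 32-character hexadecimal string
```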

`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also

@@ -94,7 +91,6 @@ The `path` field can be absolute or relative. If relative, it is taken

to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file.

### CheckpointConfig

This adds support for checkpoint configurations, and adds the

@@ -174,7 +170,7 @@ store = context.services.model_manager.store

or from elsewhere in the code by accessing
`ApiDependencies.invoker.services.model_manager.store`.

### Creating a `ModelRecordService`

To create a new `ModelRecordService` database or open an existing one,
you can directly create either a `ModelRecordServiceSQL` or a

@@ -217,27 +213,27 @@ for use in the InvokeAI web server. Its signature is:

```
def open(
    cls,
    config: InvokeAIAppConfig,
    conn: Optional[sqlite3.Connection] = None,
    lock: Optional[threading.Lock] = None
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
```

The way it works is as follows:

1. Retrieve the value of the `model_config_db` option from the user's
   `invokeai.yaml` config file.
2. If `model_config_db` is `auto` (the default), then:
   * Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
     opened on the passed connection and lock.
   * Open up a new connection to `databases/invokeai.db` if `conn`
     and/or `lock` are missing (see note below).
3. If `model_config_db` is a Path, then use `from_db_file`
   to return the appropriate type of ModelRecordService.
4. If `model_config_db` is None, then retrieve the legacy
   `conf_path` option from `invokeai.yaml` and use the Path
   indicated there. This will default to `configs/models.yaml`.

So a typical startup pattern would be:

```

@@ -255,7 +251,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)

Configurations can be retrieved in several ways.

#### get_model(key) -> AnyModelConfig

The basic functionality is to call the record store object's
`get_model()` method with the desired model's unique key. It returns

@@ -272,28 +268,28 @@ print(model_conf.path)

If the key is unrecognized, this call raises an
`UnknownModelException`.

#### exists(key) -> AnyModelConfig

Returns True if a model with the given key exists in the database.

#### search_by_path(path) -> AnyModelConfig

Returns the configuration of the model whose path is `path`. The path
is matched using a simple string comparison and won't correctly match
models referred to by different paths (e.g. using symbolic links).

#### search_by_name(name, base, type) -> List[AnyModelConfig]

This method searches for models that match some combination of `name`,
`BaseType` and `ModelType`. Calling without any arguments will return
all the models in the database.
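As an example, a lookup of a single LoRA might look like the following sketch (the argument names follow the signature shown above and could differ slightly in the actual implementation):

```
from invokeai.backend.model_manager.config import BaseModelType, ModelType

# hypothetical query: all SD-1 LoRA models named "detail-tweaker"
matches = store.search_by_name(
    name='detail-tweaker',
    base=BaseModelType('sd-1'),
    type=ModelType('lora'),
)
for model_conf in matches:
    print(model_conf.key, model_conf.path)
```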

#### all_models() -> List[AnyModelConfig]

Return all the model configs in the database. Exactly equivalent to
calling `search_by_name()` with no arguments.

#### search_by_tag(tags) -> List[AnyModelConfig]

`tags` is a list of strings. This method returns a list of model
configs that contain all of the given tags. Examples:

@@ -312,11 +308,11 @@ commercializable_models = [x for x in store.all_models() \

    if x.license.contains('allowCommercialUse=Sell')]
```

#### version() -> str

Returns the version of the database, currently at `3.2`

#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase

This method exists to ease the transition from the previous version of
the model manager, in which `get_model()` took the three arguments

@@ -337,7 +333,7 @@ model and pass its key to `get_model()`.

Several methods allow you to create and update stored model config
records.

#### add_model(key, config) -> AnyModelConfig

Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of

@@ -352,7 +348,7 @@ model with the same key is already in the database, or an

`InvalidModelConfigException` if a dict was passed and Pydantic
experienced a parse or validation error.
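A minimal sketch of registering a record by hand (the dict keys follow the example model input shown in the router section later in this document, and the key is ordinarily a freshly generated 32-character random ID):

```
import uuid

key = uuid.uuid4().hex  # stand-in for the random 32-character key
config = {
    'name': 'detail-tweaker',
    'base': 'sd-1',
    'type': 'lora',
    'format': 'lycoris',
    'path': 'loras/detail-tweaker.safetensors',
}
model_conf = store.add_model(key, config)  # raises DuplicateModelException on a duplicate key
```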

### update_model(key, config) -> AnyModelConfig

Given a key and a configuration, this will update the model
configuration record in the database. `config` can be either a

@@ -370,31 +366,31 @@ The `ModelInstallService` class implements the

shop for all your model install needs. It provides the following
functionality:

* Registering a model config record for a model already located on the
  local filesystem, without moving it or changing its path.

* Installing a model already located on the local filesystem, by
  moving it into the InvokeAI root directory under the
  `models` folder (or wherever config parameter `models_dir`
  specifies).

* Probing of models to determine their type, base type and other key
  information.

* Interface with the InvokeAI event bus to provide status updates on
  the download, installation and registration process.

* Downloading a model from an arbitrary URL and installing it in
  `models_dir`.

* Special handling for Civitai model URLs which allow the user to
  paste in a model page's URL or download link

* Special handling for HuggingFace repo_ids to recursively download
  the contents of the repository, paying attention to alternative
  variants such as fp16.

* Saving tags and other metadata about the model into the invokeai database
  when fetching from a repo that provides that type of information,
  (currently only Civitai and HuggingFace).

@@ -427,8 +423,8 @@ queue.start()

installer = ModelInstallService(app_config=config,
                                record_store=record_store,
                                download_queue=queue
                                )
installer.start()
```

@@ -443,7 +439,6 @@ required parameters:

| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
| `session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |

Once initialized, the installer will provide the following methods:

#### install_job = installer.heuristic_import(source, [config], [access_token])

@@ -457,15 +452,15 @@ The `source` is a string that can be any of these forms

1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
2. A URL pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
3. A HuggingFace repo_id with any of the following formats:
   * `model/name` -- entire model
   * `model/name:fp32` -- entire model, using the fp32 variant
   * `model/name:fp16:vae` -- vae submodel, using the fp16 variant
   * `model/name::vae` -- vae submodel, using default precision
   * `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
   * `model/name::path/to/model.safetensors` -- an individual model file, default variant

Note that by specifying a relative path to the top of the HuggingFace
repo, you can download and install arbitrary model files.
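A sketch of what calling it looks like in practice; each of the source forms above is passed as a plain string, and the return value is the same `ModelInstallJob` described under `import_model()` below:

```
# local file, HuggingFace repo_id, and arbitrary URL sources, respectively
job1 = installer.heuristic_import('/opt/models/sushi.safetensors')
job2 = installer.heuristic_import('stabilityai/sdxl-turbo::vae')
job3 = installer.heuristic_import(
    'https://civitai.com/models/58390/detail-tweaker-lora-lora',
    access_token='<optional token for gated downloads>',
)
```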

The variant, if not provided, will be automatically filled in with
`fp32` if the user has requested full precision, and `fp16`

@@ -491,9 +486,9 @@ following illustrates basic usage:

```
from invokeai.app.services.model_install import (
    LocalModelSource,
    HFModelSource,
    URLModelSource,
)

source1 = LocalModelSource(path='/opt/models/sushi.safetensors')  # a local safetensors file

@@ -513,13 +508,13 @@ for source in [source1, source2, source3, source4, source5, source6, source7]:

source2job = installer.wait_for_installs(timeout=120)
for source in sources:
    job = source2job[source]
    if job.complete:
        model_config = job.config_out
        model_key = model_config.key
        print(f"{source} installed as {model_key}")
    elif job.errored:
        print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")

```

As shown here, the `import_model()` method accepts a variety of

@@ -528,7 +523,7 @@ HuggingFace repo_ids with and without a subfolder designation,

Civitai model URLs and arbitrary URLs that point to checkpoint files
(but not to folders).

Each call to `import_model()` returns a `ModelInstallJob`,
an object which tracks the progress of the install.

If a remote model is requested, the model's files are downloaded in

@@ -555,7 +550,7 @@ The full list of arguments to `import_model()` is as follows:

| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |

The next few sections describe the various types of ModelSource that
can be passed to `import_model()`.

`config` can be used to override all or a portion of the configuration
attributes returned by the model prober. See the section below for

@@ -566,7 +561,6 @@ details.

This is used for a model that is located on a locally-accessible Posix
filesystem, such as a local disk or networked fileshare.

| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `path` | str | Path | None | Path to the model file or directory |

@@ -625,7 +619,6 @@ HuggingFace has the most complicated `ModelSource` structure:

| `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |

The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.

The `variant` is one of the various diffusers formats that HuggingFace

@@ -661,7 +654,6 @@ in. To download these files, you must provide an

`HfFolder.get_token()` will be called to fill it in with the cached
one.

#### Monitoring the install job process

When you create an install job with `import_model()`, it launches the

@@ -675,14 +667,13 @@ The `ModelInstallJob` class has the following structure:

| `id` | `int` | Integer ID for this job |
| `status` | `InstallStatus` | An enum of [`waiting`, `downloading`, `running`, `completed`, `error` and `cancelled`]|
| `config_in` | `dict` | Overriding configuration values provided by the caller |
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
| `local_path` | `Path` | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
| `error_type` | `str` | Name of the exception that led to an error status |
| `error` | `str` | Traceback of the error |

If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus
as an event of type `EventServiceBase.model_event`, a timestamp and

@@ -702,14 +693,13 @@ following keys:

| `total_bytes` | int | Total size of all the files that make up the model |
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model |

The `parts` field is a list of dictionaries that give information on each of
the component pieces of the download. The dictionary's keys are
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
the like-named keys in the main event.
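Purely as an illustration, a single downloading event payload might therefore look something like this sketch (all values are made up):

```
{
    'source': 'https://civitai.com/models/58390/detail-tweaker-lora-lora',
    'local_path': '/tmp/models/detail-tweaker.safetensors',
    'bytes': 1_048_576,
    'total_bytes': 4_194_304,
    'parts': [
        {
            'source': 'https://civitai.com/models/58390/detail-tweaker-lora-lora',
            'local_path': '/tmp/models/detail-tweaker.safetensors',
            'bytes': 1_048_576,
            'total_bytes': 4_194_304,
        },
    ],
}
```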

Note that downloading events will not be issued for local models, and
that downloading events occur _before_ the running event.

##### `model_install_running`

@@ -752,14 +742,13 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`

and `cancelled`, as well as `in_terminal_state`. The last will return
True if the job is in the complete, errored or cancelled states.

#### Model configuration and probing

The install service uses the `invokeai.backend.model_manager.probe`
module during import to determine the model's type, base type, and
other configuration parameters. Among other things, it assigns a
default name and description for the model based on probed
fields.
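A hedged sketch of invoking the prober directly (the `ModelProbe.probe()` entry point and its return shape are assumptions here and may not match the module's actual interface):

```
from pathlib import Path

from invokeai.backend.model_manager.probe import ModelProbe

# returns a configuration object populated with the probed type, base,
# format and other fields described earlier in this document
probed_config = ModelProbe.probe(Path('/opt/models/sushi.safetensors'))
print(probed_config)
```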

When downloading remote models is implemented, additional
configuration information, such as list of trigger terms, will be

@@ -774,11 +763,11 @@ attributes. Here is an example of setting the

```
install_job = installer.import_model(
    source=HFModelSource(repo_id='stabilityai/stable-diffusion-2-1', variant='fp32'),
    config=dict(
        prediction_type=SchedulerPredictionType('v_prediction'),
        name='stable diffusion 2 base model',
    )
)
```

### Other installer methods

@@ -862,7 +851,6 @@ This method is similar to `unregister()`, but also unconditionally

deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy.
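As a quick illustration (assuming both methods take the model key):

```
installer.unregister(key)  # drop the database record, leave the files in place
installer.delete(key)      # drop the record AND remove the model files from disk
```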

#### path = installer.download_and_cache(remote_source, [access_token], [timeout])

This utility routine will download the model file located at source,

@@ -953,7 +941,7 @@ following fields:

When you create a job, you can assign it a `priority`. If multiple
jobs are queued, the job with the lowest priority runs first. (Don't
blame me! The Unix developers came up with this convention.)
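For instance, a numerically lower `priority` makes a job run ahead of jobs left at the default; the sketch below assumes `priority` can simply be passed when the job is constructed:

```
rush_job = DownloadJobRemoteSource(
    source='http://www.civitai.com/models/13456',
    destination='/tmp/models/',
    priority=1,  # assumption: lower number = serviced sooner
)
```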

Every job has a `source` and a `destination`. `source` is a string in
the base class, but subclasses redefine it more specifically.

@@ -974,7 +962,7 @@ is in its lifecycle. Values are defined in the string enum

`DownloadJobStatus`, a symbol available from
`invokeai.app.services.download_manager`. Possible values are:

| **Value** | **String Value** | **Description** |
|--------------|---------------------|-------------------|
| `IDLE` | idle | Job created, but not submitted to the queue |
| `ENQUEUED` | enqueued | Job is patiently waiting on the queue |

@@ -991,7 +979,7 @@ debugging and performance testing.

In case of an error, the Exception that caused the error will be
placed in the `error` field, and the job's status will be set to
`DownloadJobStatus.ERROR`.

After an error occurs, any partially downloaded files will be deleted
from disk, unless `preserve_partial_downloads` was set to True at job

@@ -1040,11 +1028,11 @@ While a job is being downloaded, the queue will emit events at

periodic intervals. A typical series of events during a successful
download session will look like this:

* enqueued
* running
* running
* running
* completed

There will be a single enqueued event, followed by one or more running
events, and finally one `completed`, `error` or `cancelled`

@@ -1053,12 +1041,12 @@ events.

It is possible for a caller to pause download temporarily, in which
case the events may look something like this:

* enqueued
* running
* running
* paused
* running
* completed

The download queue logs when downloads start and end (unless `quiet`
is set to True at initialization time) but doesn't log any progress

@@ -1120,11 +1108,11 @@ A typical initialization sequence will look like:

from invokeai.app.services.download_manager import DownloadQueueService

def log_download_event(job: DownloadJobBase):
    logger.info(f'job={job.id}: status={job.status}')

queue = DownloadQueueService(
    event_handlers=[log_download_event]
)
```

Event handlers can be provided to the queue at initialization time as

@@ -1155,9 +1143,9 @@ To use the former method, follow this example:

```
job = DownloadJobRemoteSource(
    source='http://www.civitai.com/models/13456',
    destination='/tmp/models/',
    event_handlers=[my_handler1, my_handler2],  # if desired
)
queue.submit_download_job(job, start=True)
```

@@ -1172,13 +1160,13 @@ To have the queue create the job for you, follow this example instead:

```
job = queue.create_download_job(
    source='http://www.civitai.com/models/13456',
    destdir='/tmp/models/',
    filename='my_model.safetensors',
    event_handlers=[my_handler1, my_handler2],  # if desired
    start=True,
)
```

The `filename` argument forces the downloader to use the specified
name for the file rather than the name provided by the remote source,
and is equivalent to manually specifying a destination of

@@ -1187,7 +1175,6 @@ and is equivalent to manually specifying a destination of

Here is the full list of arguments that can be provided to
`create_download_job()`:

| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source |

@@ -1200,7 +1187,7 @@ Here is the full list of arguments that can be provided to

Internally, `create_download_job()` has a little bit of logic
that looks at the type of the source and selects the right subclass of
`DownloadJobBase` to create and enqueue.

**TODO**: move this logic into its own method for overriding in
subclasses.

@@ -1275,7 +1262,7 @@ for getting the model to run. For example "author" is metadata, while

"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`.

### Example Usage

```
from invokeai.backend.model_manager.metadata import (

@@ -1328,7 +1315,6 @@ This is the common base class for metadata:

| `author` | str | Model's author |
| `tags` | Set[str] | Model tags |

Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected

@@ -1348,7 +1334,6 @@ This descends from `ModelMetadataBase` and adds the following fields:

| `last_modified`| datetime | Date of last commit of this model to the repo |
| `files` | List[Path] | List of the files in the model repo |

#### `CivitaiMetadata`

This descends from `ModelMetadataBase` and adds the following fields:

@@ -1415,7 +1400,6 @@ testing suite to avoid hitting the internet.

The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:

#### HuggingFaceMetadataFetch

This overrides its base class `from_json()` method to return a

@@ -1434,13 +1418,12 @@ retrieves its metadata. Functionally equivalent to `from_id()`, the

only difference is that it returns a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
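A short sketch of the Civitai flow, assuming the fetcher classes are exported from `invokeai.backend.model_manager.metadata` as in the earlier example (the exact argument type accepted by `from_id()` is an assumption):

```
from invokeai.backend.model_manager.metadata import CivitaiMetadataFetch

fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_id('58390')  # the numeric ID from a civitai.com/models/58390/... page URL
print(metadata.name, sorted(metadata.tags))
```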

### Metadata Storage

The `ModelMetadataStore` provides a simple facility to store model
metadata in the `invokeai.db` database. The data is stored as a JSON
blob, with a few common fields (`name`, `author`, `tags`) broken out
to be searchable.
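For example (sketch; `metadata_store` is reached through the record service, and as noted just below the key must already exist in the model records table):

```
metadata_store = store.metadata_store             # the store attached to the record service
metadata_store.add_metadata(model_key, metadata)  # metadata: an AnyModelRepoMetadata object
saved = store.get_metadata(model_key)             # read it back through the record store
```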

When a metadata object is saved to the database, it is identified
using the model key, _and this key must correspond to an existing

@@ -1535,16 +1518,16 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist

config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
    max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
    cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
)
loader = ModelLoadService(
    app_config=config,
    ram_cache=ram_cache,
    convert_cache=convert_cache,
    registry=ModelLoaderRegistry
)
```

@@ -1567,7 +1550,6 @@ The returned `LoadedModel` object contains a copy of the configuration

record returned by the model record `get_model()` method, as well as
the in-memory loaded model:

| **Attribute Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |

@@ -1581,7 +1563,6 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,

models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.

`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns

@@ -1590,14 +1571,14 @@ the model. Use it like this:

```
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
    image = vae.decode(latents)[0]
```

`get_model_by_key()` may raise any of the following exceptions:

* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
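A defensive caller might wrap the lookup like this (sketch; imports of the exception classes are omitted):

```
try:
    model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
except UnknownModelException:
    ...  # no record with that key in the database
except ModelNotFoundException:
    ...  # the record exists, but the weights are missing on disk
```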

### Emitting model loading events

@@ -1609,15 +1590,15 @@ following payload:

```
payload=dict(
    queue_id=queue_id,
    queue_item_id=queue_item_id,
    queue_batch_id=queue_batch_id,
    graph_execution_state_id=graph_execution_state_id,
    model_key=model_key,
    submodel_type=submodel,
    hash=model_info.hash,
    location=str(model_info.location),
    precision=str(model_info.precision),
)
```

@@ -1724,6 +1705,7 @@ object, or in `context.services.model_manager` from within an

invocation.

In the examples below, we have retrieved the manager using:

```
mm = ApiDependencies.invoker.services.model_manager
```

@@ -26,7 +26,6 @@ from ..services.invocation_services import InvocationServices

from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor

@@ -93,10 +92,9 @@ class ApiDependencies:

            ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
        )
        download_queue_service = DownloadQueueService(event_bus=events)
        model_metadata_service = ModelMetadataStoreSQL(db=db)
        model_manager = ModelManagerService.build_model_manager(
            app_config=configuration,
-           model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
+           model_record_service=ModelRecordServiceSQL(db=db),
            download_queue=download_queue_service,
            events=events,
        )

@@ -3,9 +3,7 @@

import pathlib
import shutil
from hashlib import sha1
from random import randbytes
- from typing import Any, Dict, List, Optional, Set
+ from typing import Any, Dict, List, Optional

from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter

@@ -14,15 +12,11 @@ from starlette.exceptions import HTTPException

from typing_extensions import Annotated

from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
from invokeai.app.services.model_records import (
    DuplicateModelException,
    InvalidModelException,
    ModelRecordOrderBy,
    ModelSummary,
    UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,

@@ -31,9 +25,6 @@ from invokeai.backend.model_manager.config import (

    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
from invokeai.backend.model_manager.search import ModelSearch

from ..dependencies import ApiDependencies

@@ -49,15 +40,6 @@ class ModelsList(BaseModel):

    model_config = ConfigDict(use_enum_values=True)


class ModelTagSet(BaseModel):
    """Return tags for a set of models."""

    key: str
    name: str
    author: str
    tags: Set[str]


##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.

@@ -68,19 +50,16 @@ example_model_config = {

    "base": "sd-1",
    "type": "main",
    "format": "checkpoint",
-   "config": "string",
+   "config_path": "string",
    "key": "string",
-   "original_hash": "string",
-   "current_hash": "string",
+   "hash": "string",
    "description": "string",
    "source": "string",
    "last_modified": 0,
    "vae": "string",
    "converted_at": 0,
    "variant": "normal",
    "prediction_type": "epsilon",
    "repo_variant": "fp16",
    "upcast_attention": False,
    "ztsnr_training": False,
}

example_model_input = {

@@ -89,50 +68,12 @@ example_model_input = {

    "base": "sd-1",
    "type": "main",
    "format": "checkpoint",
-   "config": "configs/stable-diffusion/v1-inference.yaml",
+   "config_path": "configs/stable-diffusion/v1-inference.yaml",
    "description": "Model description",
    "vae": None,
    "variant": "normal",
}

example_model_metadata = {
    "name": "ip_adapter_sd_image_encoder",
    "author": "InvokeAI",
    "tags": [
        "transformers",
        "safetensors",
        "clip_vision_model",
        "endpoints_compatible",
        "region:us",
        "has_space",
        "license:apache-2.0",
    ],
    "files": [
        {
            "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
            "path": "ip_adapter_sd_image_encoder/README.md",
            "size": 628,
            "sha256": None,
        },
        {
            "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
            "path": "ip_adapter_sd_image_encoder/config.json",
            "size": 560,
            "sha256": None,
        },
        {
            "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
            "path": "ip_adapter_sd_image_encoder/model.safetensors",
            "size": 2528373448,
            "sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
        },
    ],
    "type": "huggingface",
    "id": "InvokeAI/ip_adapter_sd_image_encoder",
    "tag_dict": {"license": "apache-2.0"},
    "last_modified": "2023-09-23T17:33:25Z",
}

##############################################################################
# ROUTES
##############################################################################

@@ -212,89 +153,16 @@ async def get_model_record(

        raise HTTPException(status_code=404, detail=str(e))


@model_manager_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary(
    page: int = Query(default=0, description="The page to get"),
    per_page: int = Query(default=10, description="The number of models per page"),
    order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
    """Gets a page of model summary data."""
    record_store = ApiDependencies.invoker.services.model_manager.store
    results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
    return results


@model_manager_router.get(
    "/i/{key}/metadata",
    operation_id="get_model_metadata",
    responses={
        200: {
            "description": "The model metadata was retrieved successfully",
            "content": {"application/json": {"example": example_model_metadata}},
        },
        400: {"description": "Bad request"},
    },
)
async def get_model_metadata(
    key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
    """Get a model metadata object."""
    record_store = ApiDependencies.invoker.services.model_manager.store
    result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)

    return result


@model_manager_router.patch(
    "/i/{key}/metadata",
    operation_id="update_model_metadata",
    responses={
        201: {
            "description": "The model metadata was updated successfully",
            "content": {"application/json": {"example": example_model_metadata}},
        },
        400: {"description": "Bad request"},
    },
)
async def update_model_metadata(
    key: str = Path(description="Key of the model repo metadata to fetch."),
    changes: ModelMetadataChanges = Body(description="The changes"),
) -> Optional[AnyModelRepoMetadata]:
    """Updates or creates a model metadata object."""
    record_store = ApiDependencies.invoker.services.model_manager.store
    metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store

    try:
        original_metadata = record_store.get_metadata(key)
        if original_metadata:
            if changes.default_settings:
                original_metadata.default_settings = changes.default_settings

            metadata_store.update_metadata(key, original_metadata)
        else:
            metadata_store.add_metadata(
                key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
            )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while updating the model metadata: {e}",
        )

    result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)

    return result


@model_manager_router.get(
    "/tags",
    operation_id="list_tags",
)
async def list_tags() -> Set[str]:
    """Get a unique set of all the model tags."""
    record_store = ApiDependencies.invoker.services.model_manager.store
    result: Set[str] = record_store.list_tags()
    return result
# @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary(
#     page: int = Query(default=0, description="The page to get"),
#     per_page: int = Query(default=10, description="The number of models per page"),
#     order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]:
#     """Gets a page of model summary data."""
#     record_store = ApiDependencies.invoker.services.model_manager.store
#     results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
#     return results


class FoundModel(BaseModel):

@@ -366,19 +234,6 @@ async def scan_for_models(

    return scan_results


@model_manager_router.get(
    "/tags/search",
    operation_id="search_by_metadata_tags",
)
async def search_by_metadata_tags(
    tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
    """Get a list of models."""
    record_store = ApiDependencies.invoker.services.model_manager.store
    results = record_store.search_by_metadata_tag(tags)
    return ModelsList(models=results)


@model_manager_router.patch(
    "/i/{key}",
    operation_id="update_model_record",

@@ -395,15 +250,13 @@ async def search_by_metadata_tags(

)
async def update_model_record(
    key: Annotated[str, Path(description="Unique key of model")],
-   info: Annotated[
-       AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
-   ],
+   changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
) -> AnyModelConfig:
-   """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
+   """Update a model's config."""
    logger = ApiDependencies.invoker.services.logger
    record_store = ApiDependencies.invoker.services.model_manager.store
    try:
-       model_response: AnyModelConfig = record_store.update_model(key, config=info)
+       model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
        logger.info(f"Updated model: {key}")
    except UnknownModelException as e:
        raise HTTPException(status_code=404, detail=str(e))

@@ -415,14 +268,14 @@ async def update_model_record(


@model_manager_router.delete(
    "/i/{key}",
-   operation_id="del_model_record",
+   operation_id="delete_model",
    responses={
        204: {"description": "Model deleted successfully"},
        404: {"description": "Model not found"},
    },
    status_code=204,
)
- async def del_model_record(
+ async def delete_model(
    key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
    """

@@ -443,42 +296,39 @@ async def del_model_record(

        raise HTTPException(status_code=404, detail=str(e))


@model_manager_router.post(
    "/i/",
    operation_id="add_model_record",
    responses={
        201: {
            "description": "The model added successfully",
            "content": {"application/json": {"example": example_model_config}},
        },
        409: {"description": "There is already a model corresponding to this path or repo_id"},
        415: {"description": "Unrecognized file/folder format"},
    },
    status_code=201,
)
async def add_model_record(
    config: Annotated[
        AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
    ],
) -> AnyModelConfig:
    """Add a model using the configuration information appropriate for its type."""
    logger = ApiDependencies.invoker.services.logger
    record_store = ApiDependencies.invoker.services.model_manager.store
    if config.key == "<NOKEY>":
        config.key = sha1(randbytes(100)).hexdigest()
        logger.info(f"Created model {config.key} for {config.name}")
    try:
        record_store.add_model(config.key, config)
    except DuplicateModelException as e:
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e))
    except InvalidModelException as e:
        logger.error(str(e))
        raise HTTPException(status_code=415)
# @model_manager_router.post(
#     "/i/",
#     operation_id="add_model_record",
#     responses={
#         201: {
#             "description": "The model added successfully",
#             "content": {"application/json": {"example": example_model_config}},
#         },
#         409: {"description": "There is already a model corresponding to this path or repo_id"},
#         415: {"description": "Unrecognized file/folder format"},
#     },
#     status_code=201,
# )
# async def add_model_record(
#     config: Annotated[
#         AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
#     ],
# ) -> AnyModelConfig:
#     """Add a model using the configuration information appropriate for its type."""
#     logger = ApiDependencies.invoker.services.logger
#     record_store = ApiDependencies.invoker.services.model_manager.store
#     try:
#         record_store.add_model(config)
#     except DuplicateModelException as e:
#         logger.error(str(e))
#         raise HTTPException(status_code=409, detail=str(e))
#     except InvalidModelException as e:
#         logger.error(str(e))
#         raise HTTPException(status_code=415)

-   # now fetch it out
-   result: AnyModelConfig = record_store.get_model(config.key)
-   return result
+   # # now fetch it out
+   # result: AnyModelConfig = record_store.get_model(config.key)
+   # return result


@model_manager_router.post(

@@ -553,10 +403,10 @@ async def install_model(


@model_manager_router.get(
-   "/import",
-   operation_id="list_model_install_jobs",
+   "/install",
+   operation_id="list_model_installs",
)
- async def list_model_install_jobs() -> List[ModelInstallJob]:
+ async def list_model_installs() -> List[ModelInstallJob]:
    """Return the list of model install jobs.

    Install jobs have a numeric `id`, a `status`, and other fields that provide information on

@@ -570,9 +420,8 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:

    * "cancelled" -- Job was cancelled before completion.

-   Once completed, information about the model such as its size, base
-   model, type, and metadata can be retrieved from the `config_out`
-   field. For multi-file models such as diffusers, information on individual files
-   can be retrieved from `download_parts`.
+   Once completed, information about the model such as its size, base
+   model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
+   information on individual files can be retrieved from `download_parts`.

    See the example and schema below for more information.
    """

@@ -581,7 +430,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:


@model_manager_router.get(
-   "/import/{id}",
+   "/install/{id}",
    operation_id="get_model_install_job",
    responses={
        200: {"description": "Success"},

@@ -601,7 +450,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))


@model_manager_router.delete(
-   "/import/{id}",
+   "/install/{id}",
    operation_id="cancel_model_install_job",
    responses={
        201: {"description": "The job was cancelled successfully"},

@@ -619,8 +468,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job

    installer.cancel_job(job)


- @model_manager_router.patch(
-   "/import",
+ @model_manager_router.delete(
+   "/install",
    operation_id="prune_model_install_jobs",
    responses={
        204: {"description": "All completed and errored jobs have been pruned"},

@@ -699,7 +548,8 @@ async def convert_model(

    # temporarily rename the original safetensors file so that there is no naming conflict
    original_name = model_config.name
    model_config.name = f"{original_name}.DELETE"
-   store.update_model(key, config=model_config)
+   changes = ModelRecordChanges(name=model_config.name)
+   store.update_model(key, changes=changes)

    # install the diffusers
    try:

@@ -708,7 +558,7 @@ async def convert_model(

        config={
            "name": original_name,
            "description": model_config.description,
-           "original_hash": model_config.original_hash,
+           "hash": model_config.hash,
            "source": model_config.source,
        },
    )

@@ -716,10 +566,6 @@ async def convert_model(

        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e))

    # get the original metadata
    if orig_metadata := store.get_metadata(key):
        store.metadata_store.add_metadata(new_key, orig_metadata)

    # delete the original safetensors file
    installer.delete(key)

@@ -731,66 +577,66 @@ async def convert_model(

    return new_config


@model_manager_router.put(
    "/merge",
    operation_id="merge",
    responses={
        200: {
            "description": "Model converted successfully",
            "content": {"application/json": {"example": example_model_config}},
        },
        400: {"description": "Bad request"},
        404: {"description": "Model not found"},
        409: {"description": "There is already a model registered at this location"},
    },
)
async def merge(
    keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
    merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
    alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
    force: bool = Body(
        description="Force merging of models created with different versions of diffusers",
        default=False,
    ),
    interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
    merge_dest_directory: Optional[str] = Body(
        description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
        default=None,
    ),
) -> AnyModelConfig:
    """
    Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
    ```
    Argument                Description [default]
    --------                ----------------------
    keys                    List of 2-3 model keys to merge together. All models must use the same base type.
    merged_model_name       Name for the merged model [Concat model names]
    alpha                   Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
    force                   If true, force the merge even if the models were generated by different versions of the diffusers library [False]
    interp                  Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
    merge_dest_directory    Specify a directory to store the merged model in [models directory]
    ```
    """
    logger = ApiDependencies.invoker.services.logger
    try:
        logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
        dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
        installer = ApiDependencies.invoker.services.model_manager.install
        merger = ModelMerger(installer)
        model_names = [installer.record_store.get_model(x).name for x in keys]
        response = merger.merge_diffusion_models_and_save(
            model_keys=keys,
            merged_model_name=merged_model_name or "+".join(model_names),
            alpha=alpha,
            interp=interp,
            force=force,
            merge_dest_directory=dest,
        )
    except UnknownModelException:
        raise HTTPException(
            status_code=404,
            detail=f"One or more of the models '{keys}' not found",
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return response
# @model_manager_router.put(
#     "/merge",
#     operation_id="merge",
#     responses={
#         200: {
#             "description": "Model converted successfully",
#             "content": {"application/json": {"example": example_model_config}},
#         },
#         400: {"description": "Bad request"},
#         404: {"description": "Model not found"},
#         409: {"description": "There is already a model registered at this location"},
#     },
# )
# async def merge(
#     keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
#     merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
#     alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
#     force: bool = Body(
#         description="Force merging of models created with different versions of diffusers",
#         default=False,
#     ),
#     interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
#     merge_dest_directory: Optional[str] = Body(
#         description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
#         default=None,
#     ),
# ) -> AnyModelConfig:
#     """
#     Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
#     ```
#     Argument                Description [default]
#     --------                ----------------------
#     keys                    List of 2-3 model keys to merge together. All models must use the same base type.
#     merged_model_name       Name for the merged model [Concat model names]
#     alpha                   Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
#     force                   If true, force the merge even if the models were generated by different versions of the diffusers library [False]
#     interp                  Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
#     merge_dest_directory    Specify a directory to store the merged model in [models directory]
#     ```
#     """
#     logger = ApiDependencies.invoker.services.logger
#     try:
#         logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
#         dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
#         installer = ApiDependencies.invoker.services.model_manager.install
#         merger = ModelMerger(installer)
#         model_names = [installer.record_store.get_model(x).name for x in keys]
#         response = merger.merge_diffusion_models_and_save(
#             model_keys=keys,
#             merged_model_name=merged_model_name or "+".join(model_names),
#             alpha=alpha,
#             interp=interp,
#             force=force,
#             merge_dest_directory=dest,
#         )
#     except UnknownModelException:
#         raise HTTPException(
#             status_code=404,
#             detail=f"One or more of the models '{keys}' not found",
#         )
#     except ValueError as e:
#         raise HTTPException(status_code=400, detail=str(e))
#     return response
@@ -5,15 +5,7 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from invokeai.app.invocations.fields import (
    ConditioningField,
    FieldDescriptions,
    Input,
    InputField,
    MaskField,
    OutputField,
    UIComponent,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
@@ -44,7 +36,7 @@ from .model import ClipField
|
||||
title="Prompt",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
version="1.2.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -59,10 +51,6 @@ class CompelInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
mask: Optional[MaskField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
mask_weight: float = InputField(default=1.0, description="")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@@ -130,13 +118,7 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
mask=self.mask,
|
||||
mask_weight=self.mask_weight,
|
||||
)
|
||||
)
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
@@ -250,7 +232,7 @@ class SDXLPromptInvocationBase:
|
||||
title="SDXL Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.2.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -274,11 +256,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
|
||||
mask: Optional[MaskField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
mask_weight: float = InputField(default=1.0, description="")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||
@@ -340,13 +317,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
mask=self.mask,
|
||||
mask_weight=self.mask_weight,
|
||||
)
|
||||
)
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -395,7 +366,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name, mask_weight=1.0))
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
|
||||
|
||||
@invocation_output("clip_skip_output")
|
||||
|
||||
@@ -1,40 +0,0 @@
import torch

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    InvocationContext,
    invocation,
)
from invokeai.app.invocations.fields import InputField, WithMetadata
from invokeai.app.invocations.primitives import MaskField, MaskOutput


@invocation(
    "rectangle_mask",
    title="Create Rectangle Mask",
    tags=["conditioning"],
    category="conditioning",
    version="1.0.0",
)
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
    """Create a rectangular mask."""

    height: int = InputField(description="The height of the entire mask.")
    width: int = InputField(description="The width of the entire mask.")
    y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
    x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
    rectangle_height: int = InputField(description="The height of the rectangular masked region.")
    rectangle_width: int = InputField(description="The width of the rectangular masked region.")

    def invoke(self, context: InvocationContext) -> MaskOutput:
        mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
        mask[
            :, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
        ] = True

        mask_name = context.tensors.save(mask)
        return MaskOutput(
            mask=MaskField(mask_name=mask_name),
            width=self.width,
            height=self.height,
        )
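As a quick illustration of the masking logic in `RectangleMaskInvocation.invoke` above, here is a standalone sketch (plain torch, no InvokeAI services) that builds the same boolean tensor for a small example:

```
import torch

# Standalone sketch of the rectangle-mask construction above; sizes are arbitrary examples.
height, width = 4, 6
y_top, x_left = 1, 2
rectangle_height, rectangle_width = 2, 3

mask = torch.zeros((1, height, width), dtype=torch.bool)
mask[:, y_top : y_top + rectangle_height, x_left : x_left + rectangle_width] = True

# Rows 1-2, columns 2-4 are True; everything else stays False.
print(mask.int())
```
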
@@ -194,12 +194,6 @@ class BoardField(BaseModel):
|
||||
board_id: str = Field(description="The id of the board")
|
||||
|
||||
|
||||
class MaskField(BaseModel):
|
||||
"""A mask primitive field."""
|
||||
|
||||
mask_name: str = Field(description="The name of the mask.")
|
||||
|
||||
|
||||
class DenoiseMaskField(BaseModel):
|
||||
"""An inpaint mask field"""
|
||||
|
||||
@@ -231,12 +225,7 @@ class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
mask: Optional[MaskField] = Field(
|
||||
default=None,
|
||||
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
"included regions should be set to True.",
|
||||
)
|
||||
mask_weight: float = Field(description="")
|
||||
# endregion
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
import inspect
|
||||
|
||||
import math
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
@@ -9,7 +9,6 @@ import einops
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as T
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
@@ -56,14 +55,7 @@ from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
IPAdapterConditioningInfo,
|
||||
Range,
|
||||
SDXLConditioningInfo,
|
||||
TextConditioningData,
|
||||
TextConditioningRegions,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
@@ -73,6 +65,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
T2IAdapterData,
|
||||
image_resized_to_grid_as_tensor,
|
||||
)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from .baseinvocation import (
|
||||
@@ -291,11 +284,11 @@ def get_scheduler(
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
|
||||
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
||||
)
|
||||
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=0
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
||||
)
|
||||
noise: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
@@ -372,190 +365,39 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
def _get_text_embeddings_and_masks(
|
||||
self,
|
||||
cond_list: list[ConditioningField],
|
||||
context: InvocationContext,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
|
||||
"""Get the text embeddings and masks from the input conditioning fields."""
|
||||
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
|
||||
text_embeddings_masks: list[Optional[torch.Tensor]] = []
|
||||
for cond in cond_list:
|
||||
cond_data = context.conditioning.load(cond.conditioning_name)
|
||||
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
|
||||
|
||||
mask = cond.mask
|
||||
if mask is not None:
|
||||
mask = context.tensors.load(mask.mask_name)
|
||||
text_embeddings_masks.append(mask)
|
||||
|
||||
return text_embeddings, text_embeddings_masks
|
||||
|
||||
def _preprocess_regional_prompt_mask(
|
||||
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
|
||||
"""
|
||||
if mask is None:
|
||||
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
|
||||
|
||||
tf = torchvision.transforms.Resize(
|
||||
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||
mask = tf(mask)
|
||||
|
||||
return mask
|
||||
|
||||
def concat_regional_text_embeddings(
|
||||
self,
|
||||
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
||||
masks: Optional[list[Optional[torch.Tensor]]],
|
||||
conditioning_fields: list[ConditioningField],
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
|
||||
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
|
||||
if masks is None:
|
||||
masks = [None] * len(text_conditionings)
|
||||
assert len(text_conditionings) == len(masks)
|
||||
|
||||
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
|
||||
|
||||
all_masks_are_none = all(mask is None for mask in masks)
|
||||
|
||||
text_embedding = []
|
||||
pooled_embedding = None
|
||||
add_time_ids = None
|
||||
cur_text_embedding_len = 0
|
||||
processed_masks = []
|
||||
embedding_ranges = []
|
||||
extra_conditioning = None
|
||||
|
||||
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
|
||||
mask = masks[prompt_idx]
|
||||
if (
|
||||
text_embedding_info.extra_conditioning is not None
|
||||
and text_embedding_info.extra_conditioning.wants_cross_attention_control
|
||||
):
|
||||
extra_conditioning = text_embedding_info.extra_conditioning
|
||||
|
||||
if is_sdxl:
|
||||
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
|
||||
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
|
||||
# global prompt information. In an ideal case, there should be exactly one global prompt without a
|
||||
# mask, but we don't enforce this.
|
||||
|
||||
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
|
||||
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
|
||||
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
|
||||
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
|
||||
# pretty major breaking change to a popular node, so for now we use this hack.
|
||||
if pooled_embedding is None or mask is None:
|
||||
pooled_embedding = text_embedding_info.pooled_embeds
|
||||
if add_time_ids is None or mask is None:
|
||||
add_time_ids = text_embedding_info.add_time_ids
|
||||
|
||||
text_embedding.append(text_embedding_info.embeds)
|
||||
if not all_masks_are_none:
|
||||
# embedding_ranges.append(
|
||||
# Range(
|
||||
# start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
|
||||
# )
|
||||
# )
|
||||
# HACK(ryand): Contrary to its name, tokens_count_including_eos_bos does not seem to include eos and bos
|
||||
# in the count.
|
||||
embedding_ranges.append(
|
||||
Range(
|
||||
start=cur_text_embedding_len + 1,
|
||||
end=cur_text_embedding_len
|
||||
+ text_embedding_info.extra_conditioning.tokens_count_including_eos_bos,
|
||||
)
|
||||
)
|
||||
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
|
||||
|
||||
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
||||
|
||||
text_embedding = torch.cat(text_embedding, dim=1)
|
||||
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
|
||||
|
||||
regions = None
|
||||
if not all_masks_are_none:
|
||||
regions = TextConditioningRegions(
|
||||
masks=torch.cat(processed_masks, dim=1),
|
||||
ranges=embedding_ranges,
|
||||
mask_weights=[x.mask_weight for x in conditioning_fields],
|
||||
)
|
||||
|
||||
if extra_conditioning is not None and len(text_conditionings) > 1:
|
||||
raise ValueError(
|
||||
"Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
|
||||
"prompts."
|
||||
)
|
||||
|
||||
if is_sdxl:
|
||||
return SDXLConditioningInfo(
|
||||
embeds=text_embedding,
|
||||
extra_conditioning=extra_conditioning,
|
||||
pooled_embeds=pooled_embedding,
|
||||
add_time_ids=add_time_ids,
|
||||
), regions
|
||||
return BasicConditioningInfo(
|
||||
embeds=text_embedding,
|
||||
extra_conditioning=extra_conditioning,
|
||||
), regions
|
||||
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
scheduler: Scheduler,
|
||||
unet: UNet2DConditionModel,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
) -> TextConditioningData:
|
||||
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
|
||||
cond_list = self.positive_conditioning
|
||||
if not isinstance(cond_list, list):
|
||||
cond_list = [cond_list]
|
||||
uncond_list = self.negative_conditioning
|
||||
if not isinstance(uncond_list, list):
|
||||
uncond_list = [uncond_list]
|
||||
seed: int,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
cond_list, context, unet.device, unet.dtype
|
||||
)
|
||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
uncond_list, context, unet.device, unet.dtype
|
||||
)
|
||||
cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
|
||||
text_conditionings=cond_text_embeddings,
|
||||
masks=cond_text_embedding_masks,
|
||||
conditioning_fields=cond_list,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
)
|
||||
uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
|
||||
text_conditionings=uncond_text_embeddings,
|
||||
masks=uncond_text_embedding_masks,
|
||||
conditioning_fields=uncond_list,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
)
|
||||
conditioning_data = TextConditioningData(
|
||||
uncond_text=uncond_text_embedding,
|
||||
cond_text=cond_text_embedding,
|
||||
uncond_regions=uncond_regions,
|
||||
cond_regions=cond_regions,
|
||||
conditioning_data = ConditioningData(
|
||||
unconditioned_embeddings=uc,
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0, # threshold,
|
||||
warmup=0.2, # warmup,
|
||||
h_symmetry_time_pct=None, # h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None, # v_symmetry_time_pct,
|
||||
),
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
|
||||
scheduler,
|
||||
# for ddim scheduler
|
||||
eta=0.0, # ddim_eta
|
||||
# for ancestral and sde schedulers
|
||||
# flip all bits to have noise different from initial
|
||||
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
|
||||
)
|
||||
        return conditioning_data

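One detail worth calling out from the hunk above: the scheduler's `generator` is seeded with the original seed with all 32 bits flipped (`seed ^ 0xFFFFFFFF`), so the scheduler's noise differs from the initial latent noise while remaining reproducible. A tiny standalone sketch of that bit flip:

```
import torch

# Standalone sketch of the seed handling shown above; the seed value is arbitrary.
seed = 123456789
flipped = seed ^ 0xFFFFFFFF  # flip all 32 bits -> different but deterministic seed
generator = torch.Generator(device="cpu").manual_seed(flipped)

print(seed, flipped)  # 123456789 4171510506
print(torch.randn(3, generator=generator))  # reproducible "scheduler" noise
```
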
@@ -661,6 +503,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
self,
|
||||
context: InvocationContext,
|
||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
|
||||
conditioning_data: ConditioningData,
|
||||
exit_stack: ExitStack,
|
||||
) -> Optional[list[IPAdapterData]]:
|
||||
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
|
||||
@@ -677,6 +520,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
return None
|
||||
|
||||
ip_adapter_data_list = []
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
|
||||
@@ -699,13 +543,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
single_ipa_images, image_encoder_model
|
||||
)
|
||||
|
||||
conditioning_data.ip_adapter_conditioning.append(
|
||||
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
||||
)
|
||||
|
||||
ip_adapter_data_list.append(
|
||||
IPAdapterData(
|
||||
ip_adapter_model=ip_adapter_model,
|
||||
weight=single_ip_adapter.weight,
|
||||
begin_step_percent=single_ip_adapter.begin_step_percent,
|
||||
end_step_percent=single_ip_adapter.end_step_percent,
|
||||
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -795,7 +642,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
steps: int,
|
||||
denoising_start: float,
|
||||
denoising_end: float,
|
||||
seed: int,
|
||||
) -> Tuple[int, List[int], int]:
|
||||
assert isinstance(scheduler, ConfigMixin)
|
||||
if scheduler.config.get("cpu_only", False):
|
||||
@@ -824,15 +670,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
||||
num_inference_steps = len(timesteps) // scheduler.order
|
||||
|
||||
scheduler_step_kwargs = {}
|
||||
scheduler_step_signature = inspect.signature(scheduler.step)
|
||||
if "generator" in scheduler_step_signature.parameters:
|
||||
# At some point, someone decided that schedulers that accept a generator should use the original seed with
|
||||
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
|
||||
# reproducibility.
|
||||
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
|
||||
def prep_inpaint_mask(
|
||||
self, context: InvocationContext, latents: torch.Tensor
|
||||
@@ -925,10 +763,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
|
||||
)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
|
||||
|
||||
controlnet_data = self.prep_control_data(
|
||||
context=context,
|
||||
@@ -942,16 +777,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ip_adapter_data = self.prep_ip_adapter_data(
|
||||
context=context,
|
||||
ip_adapter=self.ip_adapter,
|
||||
conditioning_data=conditioning_data,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
result_latents = pipeline.latents_from_embeddings(
|
||||
@@ -964,7 +799,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
masked_latents=masked_latents,
|
||||
gradient_mask=gradient_mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=controlnet_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
|
||||
@@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -14,7 +14,6 @@ from invokeai.app.invocations.fields import (
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
MaskField,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
)
|
||||
@@ -230,18 +229,6 @@ class StringCollectionInvocation(BaseInvocation):
|
||||
# region Image
|
||||
|
||||
|
||||
@invocation_output("mask_output")
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""A torch mask tensor.
|
||||
dtype: torch.bool
|
||||
shape: (1, height, width).
|
||||
"""
|
||||
|
||||
mask: MaskField = OutputField(description="The mask.")
|
||||
width: int = OutputField(description="The width of the mask in pixels.")
|
||||
height: int = OutputField(description="The height of the mask in pixels.")
|
||||
|
||||
|
||||
@invocation_output("image_output")
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
@@ -427,6 +414,10 @@ class ConditioningOutput(BaseInvocationOutput):

    conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)

    @classmethod
    def build(cls, conditioning_name: str) -> "ConditioningOutput":
        return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))


@invocation_output("conditioning_collection_output")
class ConditioningCollectionOutput(BaseInvocationOutput):

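The `build` classmethod added above is just a convenience factory. A simplified, self-contained sketch of the same pattern, using plain pydantic stand-ins rather than the real InvokeAI classes:

```
from pydantic import BaseModel

# Simplified stand-ins for ConditioningField / ConditioningOutput; not the real classes.
class ConditioningField(BaseModel):
    conditioning_name: str

class ConditioningOutput(BaseModel):
    conditioning: ConditioningField

    @classmethod
    def build(cls, conditioning_name: str) -> "ConditioningOutput":
        return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))

out = ConditioningOutput.build("my-conditioning-tensor")
print(out.conditioning.conditioning_name)  # my-conditioning-tensor
```
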
@@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -256,6 +256,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
|
||||
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
|
||||
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
|
||||
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||
|
||||
|
||||
@@ -18,10 +18,9 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
@@ -151,6 +150,13 @@ ModelSource = Annotated[
    Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
]

MODEL_SOURCE_TO_TYPE_MAP = {
    URLModelSource: ModelSourceType.Url,
    HFModelSource: ModelSourceType.HFRepoID,
    CivitaiModelSource: ModelSourceType.CivitAI,
    LocalModelSource: ModelSourceType.Path,
}


class ModelInstallJob(BaseModel):
    """Object that tracks the current status of an install request."""

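`MODEL_SOURCE_TO_TYPE_MAP` is used further down in this diff to stamp a `source_type` onto the install job's config (`job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]`). A self-contained sketch of that class-to-enum lookup; the stand-in classes and enum values here are illustrative, not the real definitions:

```
from enum import Enum

# Illustrative stand-ins; the real classes live in the model manager / install services.
class ModelSourceType(str, Enum):
    Path = "path"
    Url = "url"

class LocalModelSource: ...
class URLModelSource: ...

MODEL_SOURCE_TO_TYPE_MAP = {
    LocalModelSource: ModelSourceType.Path,
    URLModelSource: ModelSourceType.Url,
}

source = URLModelSource()
print(MODEL_SOURCE_TO_TYPE_MAP[source.__class__])  # ModelSourceType.Url
```
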
@@ -260,7 +266,6 @@ class ModelInstallServiceBase(ABC):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: ModelMetadataStoreBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
@@ -347,6 +352,7 @@ class ModelInstallServiceBase(ABC):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
@@ -392,7 +398,7 @@ class ModelInstallServiceBase(ABC):
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record. Use it to override
|
||||
`name`, `description`, `base_type`, `model_type`, `format`,
|
||||
`prediction_type`, `image_size`, and/or `ztsnr_training`.
|
||||
`prediction_type`, and/or `image_size`.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
|
||||
@@ -20,12 +20,15 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
@@ -35,12 +38,14 @@ from invokeai.backend.model_manager.metadata import (
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
||||
|
||||
from .model_install_base import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
CivitaiModelSource,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
@@ -90,7 +95,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._next_job_id = 0
|
||||
self._metadata_store = record_store.metadata_store # for convenience
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@@ -139,6 +143,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["source_type"] = ModelSourceType.Path
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
@@ -148,11 +153,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["key"] = config.get("key", uuid_string())
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
@@ -178,7 +183,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: bool = False,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
@@ -374,15 +379,18 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
|
||||
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
key = self.install_path(job.local_path, job.config_in)
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
|
||||
# enter the metadata, if there is any
|
||||
if job.source_metadata:
|
||||
self._metadata_store.add_metadata(key, job.source_metadata)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
except InvalidModelConfigException as excp:
|
||||
@@ -468,7 +476,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
self.record_store.update_model(key, model)
|
||||
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
|
||||
return model
|
||||
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
@@ -520,24 +528,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
move(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
||||
if config: # used to override probe fields
|
||||
for key, value in config.items():
|
||||
setattr(info, key, value)
|
||||
return info
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
# Note that we may be passed a pre-populated AnyModelConfig object,
|
||||
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
|
||||
config["key"] = config.get("key", uuid_string())
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
override_key: Optional[str] = config.get("key") if config else None
|
||||
config = config or {}
|
||||
|
||||
assert info.original_hash # always assigned by probe()
|
||||
info.key = override_key or info.original_hash
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self.app_config.models_path):
|
||||
@@ -546,11 +545,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info.path = model_path.as_posix()
|
||||
|
||||
# add 'main' specific fields
|
||||
if hasattr(info, "config"):
|
||||
if isinstance(info, CheckpointConfigBase):
|
||||
# make config relative to our root
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
|
||||
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info.key, info)
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
|
||||
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info)
|
||||
return info.key
|
||||
|
||||
def _next_id(self) -> int:
|
||||
@@ -571,13 +570,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
local_path=Path(source.path),
|
||||
inplace=source.inplace,
|
||||
inplace=source.inplace or False,
|
||||
)
|
||||
|
||||
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
if not source.access_token:
|
||||
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
|
||||
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
|
||||
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
|
||||
str(source.version_id)
|
||||
)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
|
||||
@@ -605,15 +606,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# URLs from Civitai or HuggingFace will be handled specially
|
||||
url_patterns = {
|
||||
r"^https?://civitai.com/": CivitaiMetadataFetch,
|
||||
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
|
||||
}
|
||||
metadata = None
|
||||
for pattern, fetcher in url_patterns.items():
|
||||
if re.match(pattern, str(source.url), re.IGNORECASE):
|
||||
metadata = fetcher(self._session).from_url(source.url)
|
||||
break
|
||||
fetcher = None
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
except ValueError:
|
||||
pass
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
if fetcher is CivitaiMetadataFetch:
|
||||
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
|
||||
if fetcher is not None:
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
self._logger.debug(f"metadata={metadata}")
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
@@ -628,7 +631,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _import_remote_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
source: HFModelSource | CivitaiModelSource | URLModelSource,
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
@@ -656,7 +659,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
if hasattr(source, "subfolder") and source.subfolder:
|
||||
if isinstance(source, HFModelSource) and source.subfolder:
|
||||
root = Path(remote_files[0].path.parts[0])
|
||||
subfolder = root / source.subfolder
|
||||
else:
|
||||
@@ -843,3 +846,11 @@ class ModelInstallService(ModelInstallServiceBase):
        self._logger.info(f"{job.source}: model installation was cancelled")
        if self._event_bus:
            self._event_bus.emit_model_install_cancelled(str(job.source))

    @staticmethod
    def get_fetcher_from_url(url: str):
        if re.match(r"^https?://civitai.com/", url.lower()):
            return CivitaiMetadataFetch
        elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
            return HuggingFaceMetadataFetch
        raise ValueError(f"Unsupported model source: '{url}'")

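A standalone sketch mirroring the dispatch logic of the new `get_fetcher_from_url` helper (strings stand in for the fetcher classes so the snippet runs on its own):

```
import re

# Mirrors the URL dispatch above; returns class names as strings for a self-contained demo.
def get_fetcher_from_url(url: str) -> str:
    if re.match(r"^https?://civitai.com/", url.lower()):
        return "CivitaiMetadataFetch"
    elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
        return "HuggingFaceMetadataFetch"
    raise ValueError(f"Unsupported model source: '{url}'")

print(get_fetcher_from_url("https://civitai.com/models/12345"))             # CivitaiMetadataFetch
print(get_fetcher_from_url("https://huggingface.co/stabilityai/sdxl-vae"))  # HuggingFaceMetadataFetch
```
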
@@ -1,9 +0,0 @@
|
||||
"""Init file for ModelMetadataStoreService module."""
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
from .metadata_store_sql import ModelMetadataStoreSQL
|
||||
|
||||
__all__ = [
|
||||
"ModelMetadataStoreBase",
|
||||
"ModelMetadataStoreSQL",
|
||||
]
|
||||
@@ -1,81 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
|
||||
|
||||
|
||||
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
|
||||
"""A set of changes to apply to model metadata.
|
||||
Only limited changes are valid:
|
||||
- `default_settings`: the user-configured default settings for this model
|
||||
"""
|
||||
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
default=None, description="The user-configured default settings for this model"
|
||||
)
|
||||
"""The user-configured default settings for this model"""
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
@abstractmethod
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
@@ -1,223 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
if tags:
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
@@ -6,20 +6,19 @@ Abstract base class for storing and retrieving model configuration records.
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@@ -60,11 +59,33 @@ class ModelSummary(BaseModel):
    tags: Set[str] = Field(description="tags associated with model")


class ModelRecordChanges(BaseModelExcludeNull):
    """A set of changes to apply to a model."""

    # Changes applicable to all models
    name: Optional[str] = Field(description="Name of the model.", default=None)
    path: Optional[str] = Field(description="Path to the model.", default=None)
    description: Optional[str] = Field(description="Model description", default=None)
    base: Optional[BaseModelType] = Field(description="The base model.", default=None)
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
    default_settings: Optional[ModelDefaultSettings] = Field(
        description="Default settings for this model", default=None
    )

    # Checkpoint-specific changes
    # TODO(MM2): Should we expose these? Feels footgun-y...
    variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
    prediction_type: Optional[SchedulerPredictionType] = Field(
        description="The prediction type of the model.", default=None
    )
    upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)


class ModelRecordServiceBase(ABC):
    """Abstract base class for storage and retrieval of model configs."""

    @abstractmethod
    def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
    def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
        """
        Add a model to the database.

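Callers elsewhere in this diff switch from passing a full config to passing a `ModelRecordChanges` (for example `self.record_store.update_model(key, ModelRecordChanges(path=model.path))` in the install service). A short sketch of constructing such a partial update; the import path follows the one used in the install service hunk, and the final call assumes an existing record store and model key:

```
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges

# Build a partial-update payload; unset fields are simply left out of the update.
changes = ModelRecordChanges(name="my-renamed-model", description="A short description.")
print(changes.model_dump(exclude_none=True))
# {'name': 'my-renamed-model', 'description': 'A short description.'}

# Apply it (record_store and key assumed to exist):
# updated = record_store.update_model(key, changes)
```
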
@@ -88,13 +109,12 @@ class ModelRecordServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
:param key: Unique key for the model to be updated.
|
||||
:param changes: A set of changes to apply to this model. Changes are validated before being written.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -109,40 +129,6 @@ class ModelRecordServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
@@ -217,21 +203,3 @@ class ModelRecordServiceBase(ABC):
|
||||
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
||||
)
|
||||
return model_configs[0]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Rename the indicated model. Just a special case of update_model().
|
||||
|
||||
In some implementations, renaming the model may involve changing where
|
||||
it is stored on the filesystem. So this is broken out.
|
||||
|
||||
:param key: Model key
|
||||
:param new_name: New name for model
|
||||
"""
|
||||
config = self.get_model(key)
|
||||
config.name = new_name
|
||||
return self.update_model(key, config)
|
||||
|
||||
@@ -43,7 +43,7 @@ import json
|
||||
import sqlite3
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@@ -53,12 +53,11 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
DuplicateModelException,
|
||||
ModelRecordChanges,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelSummary,
|
||||
@@ -69,7 +68,7 @@ from .model_records_base import (
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
@@ -78,14 +77,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = db.conn.cursor()
|
||||
self._metadata_store = metadata_store
|
||||
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
"""Return the underlying database."""
|
||||
return self._db
|
||||
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@@ -95,23 +93,19 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object.
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
INSERT INTO models (
|
||||
id,
|
||||
original_hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?);
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
@@ -119,12 +113,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "model_config.path" in str(e):
|
||||
msg = f"A model with path '{record.path}' is already installed"
|
||||
elif "model_config.name" in str(e):
|
||||
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{key}' is already installed"
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
@@ -132,7 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
return self.get_model(config.key)
|
||||
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
@@ -146,7 +140,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
DELETE FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@@ -158,21 +152,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
record = self.get_model(key)
|
||||
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
UPDATE models
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
@@ -199,7 +192,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@@ -220,7 +213,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM model_config
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@@ -246,9 +239,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
results = []
|
||||
where_clause = []
|
||||
bindings = []
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
@@ -265,14 +257,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
select config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
result = self._cursor.fetchall()
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
@@ -281,7 +272,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
@@ -292,13 +283,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated original_hash."""
|
||||
"""Return models with the indicated hash."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE original_hash=?;
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
@@ -307,83 +298,35 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
]
|
||||
return results
|
||||
|
||||
@property
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
return self._metadata_store
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
store = self.metadata_store
|
||||
try:
|
||||
metadata = store.get_metadata(key)
|
||||
return metadata
|
||||
except UnknownMetadataException:
|
||||
return None
|
||||
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
keys = store.search_by_tag(tags)
|
||||
return [self.get_model(x) for x in keys]
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_tags()
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_all_metadata()
|
||||
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
|
||||
ModelRecordOrderBy.Type: "a.type",
|
||||
ModelRecordOrderBy.Base: "a.base",
|
||||
ModelRecordOrderBy.Name: "a.name",
|
||||
ModelRecordOrderBy.Format: "a.format",
|
||||
ModelRecordOrderBy.Default: "type, base, format, name",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
|
||||
"""Fix up results so that there are no null values."""
|
||||
result: Dict[str, Union[str, int, Set[str]]] = {}
|
||||
for key, item in summary.items():
|
||||
result[key] = item or ""
|
||||
result["tags"] = set(json.loads(summary["tags"] or "[]"))
|
||||
return result
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
with self._db.lock:
|
||||
# query1: get the total number of model configs
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from model_config;
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(self._cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields from the join of model_config and model_metadata
|
||||
# query2: fetch key fields
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT a.id as key, a.type, a.base, a.format, a.name,
|
||||
json_extract(a.config, '$.description') as description,
|
||||
json_extract(b.metadata, '$.tags') as tags
|
||||
FROM model_config AS a
|
||||
LEFT JOIN model_metadata AS b on a.id=b.id
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
@@ -394,7 +337,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
return PaginatedResults(
|
||||
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
||||
)
|
||||
|
||||
@@ -1,6 +1,35 @@
from abc import ABC, abstractmethod
from threading import Event

from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem


class SessionRunnerBase(ABC):
    """
    Base class for session runner.
    """

    @abstractmethod
    def start(self, services: InvocationServices, cancel_event: Event) -> None:
        """Starts the session runner"""
        pass

    @abstractmethod
    def run(self, queue_item: SessionQueueItem) -> None:
        """Runs the session"""
        pass

    @abstractmethod
    def complete(self, queue_item: SessionQueueItem) -> None:
        """Completes the session"""
        pass

    @abstractmethod
    def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
        """Runs an already prepared node on the session"""
        pass


class SessionProcessorBase(ABC):

@@ -2,13 +2,14 @@ import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
@@ -16,15 +17,164 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import SessionProcessorBase
|
||||
from .session_processor_base import SessionProcessorBase, SessionRunnerBase
|
||||
from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
|
||||
class DefaultSessionRunner(SessionRunnerBase):
|
||||
"""Processes a single session's invocations"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
):
|
||||
self.on_before_run_node = on_before_run_node
|
||||
self.on_after_run_node = on_after_run_node
|
||||
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
|
||||
"""Start the session runner"""
|
||||
self.services = services
|
||||
self.cancel_event = cancel_event
|
||||
|
||||
def run(self, queue_item: SessionQueueItem):
|
||||
"""Run the graph"""
|
||||
if not queue_item.session:
|
||||
raise ValueError("Queue item has no session")
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
|
||||
# Prepare the next node
|
||||
invocation = queue_item.session.next()
|
||||
if invocation is None:
|
||||
# If there are no more invocations, complete the graph
|
||||
break
|
||||
# Build invocation context (the node-facing API)
|
||||
self.run_node(invocation.id, queue_item)
|
||||
self.complete(queue_item)
|
||||
|
||||
def complete(self, queue_item: SessionQueueItem):
|
||||
"""Complete the graph"""
|
||||
self.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
)
|
||||
|
||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run before a node is executed"""
|
||||
# Send starting event
|
||||
self.services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
)
|
||||
if self.on_before_run_node is not None:
|
||||
self.on_before_run_node(invocation, queue_item)
|
||||
|
||||
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run after a node is executed"""
|
||||
if self.on_after_run_node is not None:
|
||||
self.on_after_run_node(invocation, queue_item)
|
||||
|
||||
def run_node(self, node_id: str, queue_item: SessionQueueItem):
|
||||
"""Run a single node in the graph"""
|
||||
# If this error raises a NodeNotFoundError that's handled by the processor
|
||||
invocation = queue_item.session.execution_graph.get_node(node_id)
|
||||
try:
|
||||
self._on_before_run_node(invocation, queue_item)
|
||||
data = InvocationContextData(
|
||||
invocation=invocation,
|
||||
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
queue_item=queue_item,
|
||||
)
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self.services,
|
||||
cancel_event=self.cancel_event,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
outputs = invocation.invoke_internal(context=context, services=self.services)
|
||||
|
||||
# Save outputs and history
|
||||
queue_item.session.complete(invocation.id, outputs)
|
||||
|
||||
self._on_after_run_node(invocation, queue_item)
|
||||
# Send complete event on successful runs
|
||||
self.services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=data.source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
queue_item.session.set_node_error(invocation.id, error)
|
||||
self.services.logger.error(
|
||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
|
||||
"""Processes sessions from the session queue"""
|
||||
|
||||
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
|
||||
super().__init__()
|
||||
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
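A hedged usage sketch (not from the diff) of how the runner/processor split and the node-level callbacks introduced above might be wired together; the callback signatures follow the `DefaultSessionRunner` constructor, and the print calls are illustrative only.

```python
def log_before(invocation, queue_item) -> bool:
    print(f"starting node {invocation.id} for queue item {queue_item.item_id}")
    return True


def log_after(invocation, queue_item) -> bool:
    print(f"finished node {invocation.id}")
    return True


runner = DefaultSessionRunner(on_before_run_node=log_before, on_after_run_node=log_after)
processor = DefaultSessionProcessor(session_runner=runner)
# processor.start(...) is then called by the invoker with the services and polling settings shown below.
```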
|
||||
|
||||
def start(
|
||||
self,
|
||||
invoker: Invoker,
|
||||
thread_limit: int = 1,
|
||||
polling_interval: int = 1,
|
||||
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||
) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
self._invocation: Optional[BaseInvocation] = None
|
||||
self.on_before_run_session = on_before_run_session
|
||||
self.on_after_run_session = on_after_run_session
|
||||
|
||||
self._resume_event = ThreadEvent()
|
||||
self._stop_event = ThreadEvent()
|
||||
@@ -59,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
"cancel_event": self._cancel_event,
|
||||
},
|
||||
)
|
||||
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
@@ -117,131 +268,34 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# If we have a on_before_run_session callback, call it
|
||||
if self.on_before_run_session is not None:
|
||||
self.on_before_run_session(self._queue_item)
|
||||
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||
|
||||
# Prepare invocations and take the first
|
||||
self._invocation = self._queue_item.session.next()
|
||||
# Run the graph
|
||||
self.session_runner.run(queue_item=self._queue_item)
|
||||
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
with self._invoker.services.performance_statistics.collect_stats(
|
||||
self._invocation, self._queue_item.session.id
|
||||
):
|
||||
# Build invocation context (the node-facing API)
|
||||
data = InvocationContextData(
|
||||
invocation=self._invocation,
|
||||
source_invocation_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._invoker.services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
|
||||
# Invoke the node
|
||||
outputs = self._invocation.invoke_internal(
|
||||
context=context, services=self._invoker.services
|
||||
)
|
||||
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
pass
|
||||
|
||||
# The session is complete if all the invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
|
||||
# Set the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
else:
|
||||
# Prepare the next invocation
|
||||
self._invocation = self._queue_item.session.next()
|
||||
# If we have a on_after_run_session callback, call it
|
||||
if self.on_after_run_session is not None:
|
||||
self.on_after_run_session(self._queue_item)
|
||||
|
||||
# The session is complete, immediately poll for next session
|
||||
self._queue_item = None
|
||||
@@ -275,3 +329,4 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.clear()
|
||||
self._queue_item = None
|
||||
self._thread_semaphore.release()
|
||||
self._invoker.services.logger.debug("Session processor stopped")
|
||||
|
||||
@@ -9,6 +9,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator

@@ -35,6 +36,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
    migrator.register_migration(build_migration_4())
    migrator.register_migration(build_migration_5())
    migrator.register_migration(build_migration_6())
    migrator.register_migration(build_migration_7())
    migrator.run_migrations()

    return db

@@ -0,0 +1,88 @@
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration7Callback:
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._create_models_table(cursor)
        self._drop_old_models_tables(cursor)

    def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
        """Drops the old model_records, model_metadata, model_tags and tags tables."""

        tables = ["model_records", "model_metadata", "model_tags", "tags"]

        for table in tables:
            cursor.execute(f"DROP TABLE IF EXISTS {table};")

    def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
        """Creates the v4.0.0 models table."""

        tables = [
            """--sql
            CREATE TABLE IF NOT EXISTS models (
                id TEXT NOT NULL PRIMARY KEY,
                hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
                base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
                type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
                path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
                format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
                name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
                description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
                source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
                source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
                source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
                trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
                -- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
                config TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- unique constraint on combo of name, base and type
                UNIQUE(name, base, type)
            );
            """
        ]

        # Add trigger for `updated_at`.
        triggers = [
            """--sql
            CREATE TRIGGER IF NOT EXISTS models_updated_at
            AFTER UPDATE
            ON models FOR EACH ROW
            BEGIN
                UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE id = old.id;
            END;
            """
        ]

        # Add indexes for searchable fields
        indices = [
            "CREATE INDEX IF NOT EXISTS base_index ON models(base);",
            "CREATE INDEX IF NOT EXISTS type_index ON models(type);",
            "CREATE INDEX IF NOT EXISTS name_index ON models(name);",
            "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
        ]

        for stmt in tables + indices + triggers:
            cursor.execute(stmt)


def build_migration_7() -> Migration:
    """
    Build the migration from database version 6 to 7.

    This migration does the following:
        - Adds the new models table
        - Drops the old model_records, model_metadata, model_tags and tags tables.
        - TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
    """
    migration_7 = Migration(
        from_version=6,
        to_version=7,
        callback=Migration7Callback(),
    )

    return migration_7
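The new `models` table leans on SQLite generated columns: every searchable field is a VIRTUAL column extracted from the serialized `config` JSON, so the JSON blob stays the single source of truth while `name`, `base`, `type`, etc. remain queryable and indexable. A small self-contained sketch of the mechanism (requires SQLite 3.31+ with JSON1; the table and values here are illustrative, not the real schema):

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE models (
        id TEXT NOT NULL PRIMARY KEY,
        name TEXT GENERATED ALWAYS AS (json_extract(config, '$.name')) VIRTUAL NOT NULL,
        config TEXT NOT NULL
    );
    """
)
# Only id and config are written; name is derived on the fly from the JSON.
conn.execute(
    "INSERT INTO models (id, config) VALUES (?, ?);",
    ("abc123", json.dumps({"name": "my-model", "base": "sd-1"})),
)
print(conn.execute("SELECT name FROM models WHERE id = 'abc123';").fetchone())  # ('my-model',)
```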
@@ -150,7 +150,7 @@ class MigrateModelYamlToDb1:
            """,
            (
                key,
                record.original_hash,
                record.hash,
                json_serialized,
            ),
        )

@@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
    See :class:`Migration` for an example.
    """

    def __call__(self, cursor: sqlite3.Cursor) -> None: ...
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        ...


class MigrationError(RuntimeError):

@@ -1,55 +0,0 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.services.shared.graph import Edge
|
||||
|
||||
|
||||
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||
"""
|
||||
Parses raw session string, returning a dict of the graph.
|
||||
|
||||
Only the general graph shape is validated; none of the fields are validated.
|
||||
|
||||
Any `metadata_accumulator` nodes and edges are removed.
|
||||
|
||||
Any validation failure will return None.
|
||||
"""
|
||||
|
||||
graph = json.loads(session_raw).get("graph", None)
|
||||
|
||||
# sanity check make sure the graph is at least reasonably shaped
|
||||
if (
|
||||
not isinstance(graph, dict)
|
||||
or "nodes" not in graph
|
||||
or not isinstance(graph["nodes"], dict)
|
||||
or "edges" not in graph
|
||||
or not isinstance(graph["edges"], list)
|
||||
):
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
try:
|
||||
# delete the `metadata_accumulator` node
|
||||
del graph["nodes"]["metadata_accumulator"]
|
||||
except KeyError:
|
||||
# no accumulator node, all good
|
||||
pass
|
||||
|
||||
# delete any edges to or from it
|
||||
for i, edge in enumerate(graph["edges"]):
|
||||
try:
|
||||
# try to parse the edge
|
||||
Edge(**edge)
|
||||
except ValidationError:
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
if (
|
||||
edge["source"]["node_id"] == "metadata_accumulator"
|
||||
or edge["destination"]["node_id"] == "metadata_accumulator"
|
||||
):
|
||||
del graph["edges"][i]
|
||||
|
||||
return graph
|
||||
invokeai/backend/ip_adapter/attention_processor.py (new file, 182 lines)
@@ -0,0 +1,182 @@
|
||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||
# and modified as needed
|
||||
|
||||
# tencent-ailab comment:
|
||||
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
|
||||
|
||||
|
||||
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
|
||||
# loading.
|
||||
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
||||
def __init__(self):
|
||||
DiffusersAttnProcessor2_0.__init__(self)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
|
||||
ip_adapter_image_prompt_embeds parameter.
|
||||
"""
|
||||
return DiffusersAttnProcessor2_0.__call__(
|
||||
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
|
||||
)
|
||||
|
||||
|
||||
class IPAttnProcessor2_0(torch.nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapter for PyTorch 2.0.
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
assert len(weights) == len(scales)
|
||||
|
||||
self._weights = weights
|
||||
self._scales = scales
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
"""Apply IP-Adapter attention.
|
||||
|
||||
Args:
|
||||
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
|
||||
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
||||
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
|
||||
|
||||
for ipa_embed, ipa_weights, scale in zip(
|
||||
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
|
||||
):
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The token_len dimensions should match.
|
||||
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
||||
|
||||
ip_hidden_states = ipa_embed
|
||||
|
||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
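In equation form, the decoupled cross-attention implemented above adds an image-prompt attention term to the ordinary text cross-attention (this is a summary following the IP-Adapter paper, not text taken from the diff):

$$
\mathbf{Z}^{\text{new}} = \mathrm{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) \;+\; \sum_{i} s_i \,\mathrm{Attention}(\mathbf{Q}, \mathbf{K}'_i, \mathbf{V}'_i)
$$

where $\mathbf{K}'_i$ and $\mathbf{V}'_i$ are projected from the $i$-th adapter's image-prompt embeddings via its `to_k_ip`/`to_v_ip` weights, and $s_i$ is the per-adapter scale.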
|
||||
@@ -1,55 +1,52 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor2_0
|
||||
|
||||
|
||||
class UNetAttentionPatcher:
|
||||
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
|
||||
class UNetPatcher:
|
||||
"""A class that contains multiple IP-Adapters and can apply them to a UNet."""
|
||||
|
||||
def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
|
||||
def __init__(self, ip_adapters: list[IPAdapter]):
|
||||
self._ip_adapters = ip_adapters
|
||||
self._ip_adapter_scales = None
|
||||
|
||||
if self._ip_adapters is not None:
|
||||
self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
|
||||
self._scales = [1.0] * len(self._ip_adapters)
|
||||
|
||||
def set_scale(self, idx: int, value: float):
|
||||
self._ip_adapter_scales[idx] = value
|
||||
self._scales[idx] = value
|
||||
|
||||
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
|
||||
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
|
||||
weights into them (if IP-Adapters are being applied).
|
||||
weights into them.
|
||||
|
||||
Note that the `unet` param is only used to determine attention block dimensions and naming.
|
||||
"""
|
||||
# Construct a dict of attention processors based on the UNet's architecture.
|
||||
attn_procs = {}
|
||||
for idx, name in enumerate(unet.attn_processors.keys()):
|
||||
if name.endswith("attn1.processor") or self._ip_adapters is None:
|
||||
# "attn1" processors do not use IP-Adapters.
|
||||
attn_procs[name] = CustomAttnProcessor2_0()
|
||||
if name.endswith("attn1.processor"):
|
||||
attn_procs[name] = AttnProcessor2_0()
|
||||
else:
|
||||
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
||||
attn_procs[name] = CustomAttnProcessor2_0(
|
||||
attn_procs[name] = IPAttnProcessor2_0(
|
||||
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
|
||||
self._ip_adapter_scales,
|
||||
self._scales,
|
||||
)
|
||||
return attn_procs
|
||||
|
||||
@contextmanager
|
||||
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
|
||||
"""A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
|
||||
"""A context manager that patches `unet` with IP-Adapter attention processors."""
|
||||
|
||||
attn_procs = self._prepare_attention_processors(unet)
|
||||
|
||||
orig_attn_processors = unet.attn_processors
|
||||
|
||||
try:
|
||||
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
|
||||
# the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
|
||||
# moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
|
||||
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
|
||||
# passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
|
||||
# of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
|
||||
unet.set_attn_processor(attn_procs)
|
||||
yield None
|
||||
finally:
|
||||
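A hedged usage sketch (assuming the pre-change `UNetPatcher` API shown in this hunk; `unet`, `ip_adapter`, `latents`, `timestep` and `text_embeds` are placeholders, not names from the diff) of how the context manager is typically used around a denoising call:

```python
patcher = UNetPatcher([ip_adapter])  # one or more loaded IPAdapter instances
patcher.set_scale(0, 0.75)           # weight of the image prompt for adapter 0

with patcher.apply_ip_adapter_attention(unet):
    # Inside the context the UNet's cross-attention layers are IPAttnProcessor2_0 instances;
    # on exit the original attention processors are restored.
    noise_pred = unet(latents, timestep, encoder_hidden_states=text_embeds).sample
```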
@@ -25,10 +25,13 @@ from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
@@ -56,8 +59,8 @@ class ModelType(str, Enum):
|
||||
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
VAE = "vae"
|
||||
LoRA = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
@@ -73,9 +76,9 @@ class SubModelType(str, Enum):
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
VAE = "vae"
|
||||
VAEDecoder = "vae_decoder"
|
||||
VAEEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
@@ -93,8 +96,8 @@ class ModelFormat(str, Enum):
|
||||
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
Lycoris = "lycoris"
|
||||
Onnx = "onnx"
|
||||
LyCORIS = "lycoris"
|
||||
ONNX = "onnx"
|
||||
Olive = "olive"
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
@@ -112,128 +115,187 @@ class SchedulerPredictionType(str, Enum):
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
|
||||
Default = "" # model files without "fp16" or other qualifier - empty str
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
OPENVINO = "openvino"
|
||||
FLAX = "flax"
|
||||
OpenVINO = "openvino"
|
||||
Flax = "flax"
|
||||
|
||||
|
||||
class ModelSourceType(str, Enum):
|
||||
"""Model source type."""
|
||||
|
||||
Path = "path"
|
||||
Url = "url"
|
||||
HFRepoID = "hf_repo_id"
|
||||
CivitAI = "civitai"
|
||||
|
||||
|
||||
class ModelDefaultSettings(BaseModel):
|
||||
vae: str | None
|
||||
vae_precision: str | None
|
||||
scheduler: SCHEDULER_NAME_VALUES | None
|
||||
steps: int | None
|
||||
cfg_scale: float | None
|
||||
cfg_rescale_multiplier: float | None
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base class for model configuration information."""
|
||||
|
||||
path: str = Field(description="filesystem path to the model file or directory")
|
||||
name: str = Field(description="model name")
|
||||
base: BaseModelType = Field(description="base model")
|
||||
type: ModelType = Field(description="type of the model")
|
||||
format: ModelFormat = Field(description="model format")
|
||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||
original_hash: Optional[str] = Field(
|
||||
description="original fasthash of model contents", default=None
|
||||
) # this is assigned at install time and will not change
|
||||
current_hash: Optional[str] = Field(
|
||||
description="current fasthash of model contents", default=None
|
||||
) # if model is converted or otherwise modified, this will hold updated hash
|
||||
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
||||
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
|
||||
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
|
||||
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
|
||||
hash: str = Field(description="The hash of the model file(s).")
|
||||
path: str = Field(
|
||||
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
|
||||
)
|
||||
name: str = Field(description="Name of the model.")
|
||||
base: BaseModelType = Field(description="The base model.")
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
||||
source_type: ModelSourceType = Field(description="The type of source")
|
||||
source_api_response: Optional[str] = Field(
|
||||
description="The original API response from the source, as stringified JSON.", default=None
|
||||
)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
schema["required"].extend(
|
||||
["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
|
||||
)
|
||||
schema["required"].extend(["key", "type", "format"])
|
||||
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=False,
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
)
|
||||
|
||||
def update(self, attributes: Dict[str, Any]) -> None:
|
||||
"""Update the object with fields in dict."""
|
||||
for key, value in attributes.items():
|
||||
setattr(self, key, value) # may raise a validation error
|
||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
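For illustration only (the field values and the `BaseModelType` member name are assumptions, not taken from the diff), constructing a record under the new schema looks roughly like this; `key` defaults to a fresh `uuid_string()`, and `type`/`format` are pinned by the concrete subclass:

```python
config = MainDiffusersConfig(
    hash="0b1c2d3e4f5a6b7c",                 # content hash computed at install time
    path="sd-1/main/my-model",               # relative to the Invoke root directory
    name="my-model",
    base=BaseModelType.StableDiffusion1,     # assumed enum member
    source="https://example.com/my-model",
    source_type=ModelSourceType.Url,
    description="An example model record",
)
```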
|
||||
|
||||
|
||||
class _CheckpointConfig(ModelConfigBase):
|
||||
class CheckpointConfigBase(ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
config: str = Field(description="path to the checkpoint model config file")
|
||||
config_path: str = Field(description="path to the checkpoint model config file")
|
||||
converted_at: Optional[float] = Field(
|
||||
description="When this model was last converted to diffusers", default_factory=time.time
|
||||
)
|
||||
|
||||
|
||||
class _DiffusersConfig(ModelConfigBase):
|
||||
class DiffusersConfigBase(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
||||
|
||||
|
||||
class LoRAConfig(ModelConfigBase):
|
||||
class LoRALyCORISConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
|
||||
class VaeCheckpointConfig(ModelConfigBase):
|
||||
class LoRADiffusersConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class VAECheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
class VaeDiffusersConfig(ModelConfigBase):
|
||||
|
||||
class VAEDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for standalone VAE models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||
|
||||
class ControlNetCheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
class TextualInversionConfig(ModelConfigBase):
|
||||
|
||||
class TextualInversionFileConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
||||
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
|
||||
|
||||
|
||||
class _MainConfig(ModelConfigBase):
|
||||
"""Model config for main models."""
|
||||
class TextualInversionFolderConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
vae: Optional[str] = Field(default=None)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
ztsnr_training: bool = False
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
||||
|
||||
|
||||
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||
class MainCheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
class MainDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class IPAdapterConfig(ModelConfigBase):
    """Model config for IP Adaptor format models."""

@@ -242,6 +304,10 @@ class IPAdapterConfig(ModelConfigBase):
    image_encoder_model_id: str
    format: Literal[ModelFormat.InvokeAI]

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")


class CLIPVisionDiffusersConfig(ModelConfigBase):
    """Model config for ClipVision."""

@@ -249,58 +315,65 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
    type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
    format: Literal[ModelFormat.Diffusers]

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")


class T2IConfig(ModelConfigBase):
class T2IAdapterConfig(ModelConfigBase):
    """Model config for T2I."""

    type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
    format: Literal[ModelFormat.Diffusers]

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
_ControlNetConfig = Annotated[
    Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
    Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]

AnyModelConfig = Union[
    _MainModelConfig,
    _VaeConfig,
    _ControlNetConfig,
    # ModelConfigBase,
    LoRAConfig,
    TextualInversionConfig,
    IPAdapterConfig,
    CLIPVisionDiffusersConfig,
    T2IConfig,


def get_model_discriminator_value(v: Any) -> str:
    """
    Computes the discriminator value for a model config.
    https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
    """
    format_ = None
    type_ = None
    if isinstance(v, dict):
        format_ = v.get("format")
        if isinstance(format_, Enum):
            format_ = format_.value
        type_ = v.get("type")
        if isinstance(type_, Enum):
            type_ = type_.value
    else:
        format_ = v.format.value
        type_ = v.type.value
    v = f"{type_}.{format_}"
    return v


AnyModelConfig = Annotated[
    Union[
        Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
        Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
        Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
        Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
        Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
        Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
        Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
        Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
        Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
        Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
        Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
        Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]

AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
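For readers unfamiliar with pydantic's callable-discriminator feature used above, the following standalone sketch shows the same `type.format` tagging pattern with made-up config classes and tag strings; none of the names below come from the InvokeAI codebase:

```python
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class MainCheckpoint(BaseModel):
    type: Literal["main"] = "main"
    format: Literal["checkpoint"] = "checkpoint"
    path: str


class VAEDiffusers(BaseModel):
    type: Literal["vae"] = "vae"
    format: Literal["diffusers"] = "diffusers"
    path: str


def discriminator(v: Any) -> str:
    # Works for both raw dicts (API input) and already-built model instances.
    if isinstance(v, dict):
        return f"{v.get('type')}.{v.get('format')}"
    return f"{v.type}.{v.format}"


AnyConfig = Annotated[
    Union[
        Annotated[MainCheckpoint, Tag("main.checkpoint")],
        Annotated[VAEDiffusers, Tag("vae.diffusers")],
    ],
    Discriminator(discriminator),
]

validator = TypeAdapter(AnyConfig)
config = validator.validate_python({"type": "vae", "format": "diffusers", "path": "/models/vae"})
assert isinstance(config, VAEDiffusers)
```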
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
#   https://github.com/tiangolo/fastapi/discussions/9761 and
#   https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
#     Union[
#         _MainModelConfig,
#         _ONNXConfig,
#         _VaeConfig,
#         _ControlNetConfig,
#         LoRAConfig,
#         TextualInversionConfig,
#         IPAdapterConfig,
#         CLIPVisionDiffusersConfig,
#         T2IConfig,
#     ],
#     Field(discriminator="type"),
# ]


class ModelConfigFactory(object):
    """Class for parsing config dicts into StableDiffusion Config obects."""

@@ -332,6 +405,6 @@ class ModelConfigFactory(object):
        assert model is not None
        if key:
            model.key = key
        if timestamp:
            model.last_modified = timestamp
        if isinstance(model, CheckpointConfigBase) and timestamp is not None:
            model.converted_at = timestamp
        return model  # type: ignore
@@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
    ModelRepoVariant,
    SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
        :param submodel_type: an ModelType enum indicating the portion of
            the model to retrieve (e.g. ModelType.Vae)
        """
        if model_config.type == "main" and not submodel_type:
        if model_config.type is ModelType.Main and not submodel_type:
            raise InvalidModelConfigException("submodel_type is required when loading a main model")

        model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
@@ -80,7 +81,7 @@ class ModelLoader(ModelLoaderBase):
        self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
        return self._convert_model(config, model_path, cache_path)

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
        return False

    def _load_if_needed(
@@ -119,7 +120,7 @@ class ModelLoader(ModelLoaderBase):
        return calc_model_size_by_fs(
            model_path=model_path,
            subfolder=submodel_type.value if submodel_type else None,
            variant=config.repo_variant if hasattr(config, "repo_variant") else None,
            variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
        )

    # This needs to be implemented in subclasses that handle checkpoints
@@ -15,10 +15,8 @@ Use like this:
"""

import hashlib
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar

from ..config import (
    AnyModelConfig,
@@ -27,8 +25,6 @@ from ..config import (
    ModelFormat,
    ModelType,
    SubModelType,
    VaeCheckpointConfig,
    VaeDiffusersConfig,
)
from . import ModelLoaderBase

@@ -61,6 +57,9 @@ class ModelLoaderRegistryBase(ABC):
    """


TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)


class ModelLoaderRegistry:
    """
    This class allows model loaders to register their type, base and format.
@@ -71,10 +70,10 @@ class ModelLoaderRegistry:
    @classmethod
    def register(
        cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
    ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
    ) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
        """Define a decorator which registers the subclass of loader."""

        def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
        def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
            key = cls._to_registry_key(base, type, format)
            if key in cls._registry:
                raise Exception(
@@ -90,33 +89,15 @@ class ModelLoaderRegistry:
        cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
    ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
        """Get subclass of ModelLoaderBase registered to handle base and type."""
        # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
        conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)

        key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format)  # for a specific base type
        key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format)  # with wildcard Any
        key1 = cls._to_registry_key(config.base, config.type, config.format)  # for a specific base type
        key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format)  # with wildcard Any
        implementation = cls._registry.get(key1) or cls._registry.get(key2)
        if not implementation:
            raise NotImplementedError(
                f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
            )
        return implementation, conf2, submodel_type

    @classmethod
    def _handle_subtype_overrides(
        cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
    ) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
        if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
            model_path = Path(config.vae)
            config_class = (
                VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
            )
            hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
            new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
            submodel_type = None
        else:
            new_conf = config
        return new_conf, submodel_type
        return implementation, config, submodel_type

    @staticmethod
    def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
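The `TModelLoader` TypeVar introduced here is what lets `register()` hand the decorated class back unchanged from the type checker's point of view. A minimal standalone sketch of the same registry pattern, using hypothetical class and key names rather than the real InvokeAI types:

```python
from typing import Callable, Dict, Tuple, Type, TypeVar


class LoaderBase:
    """Stand-in for ModelLoaderBase."""


TLoader = TypeVar("TLoader", bound=LoaderBase)


class LoaderRegistry:
    _registry: Dict[Tuple[str, str, str], Type[LoaderBase]] = {}

    @classmethod
    def register(cls, base: str, type: str, format: str) -> Callable[[Type[TLoader]], Type[TLoader]]:
        # Returning Type[TLoader] (not Type[LoaderBase]) preserves the decorated
        # class's own type for callers and static analysis.
        def decorator(subclass: Type[TLoader]) -> Type[TLoader]:
            key = (base, type, format)
            if key in cls._registry:
                raise Exception(f"{key} already registered")
            cls._registry[key] = subclass
            return subclass

        return decorator

    @classmethod
    def get(cls, base: str, type: str, format: str) -> Type[LoaderBase]:
        # Fall back to the wildcard "any" base, mirroring the two-key lookup above.
        return cls._registry.get((base, type, format)) or cls._registry[("any", type, format)]


@LoaderRegistry.register(base="any", type="controlnet", format="checkpoint")
class ControlNetLoader(LoaderBase):
    pass


loader_cls = LoaderRegistry.get(base="sd-1", type="controlnet", format="checkpoint")
assert loader_cls is ControlNetLoader
```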
@@ -3,8 +3,8 @@

from pathlib import Path

import safetensors
import torch
from safetensors.torch import load_file as safetensors_load_file

from invokeai.backend.model_manager import (
    AnyModelConfig,
@@ -12,6 +12,7 @@ from invokeai.backend.model_manager import (
    ModelFormat,
    ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers

from .. import ModelLoaderRegistry
@@ -20,15 +21,15 @@ from .generic_diffusers import GenericDiffusersLoader

@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlnetLoader(GenericDiffusersLoader):
class ControlNetLoader(GenericDiffusersLoader):
    """Class to load ControlNet models."""

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
        if config.format != ModelFormat.Checkpoint:
        if not isinstance(config, CheckpointConfigBase):
            return False
        elif (
            dest_path.exists()
            and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
            and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
            and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
        ):
            return False
@@ -37,13 +38,13 @@ class ControlnetLoader(GenericDiffusersLoader):

    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
        if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
            raise Exception(f"Vae conversion not supported for model type: {config.base}")
            raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
        else:
            assert hasattr(config, "config")
            config_file = config.config
            assert isinstance(config, CheckpointConfigBase)
            config_file = config.config_path

        if model_path.suffix == ".safetensors":
            checkpoint = safetensors.torch.load_file(model_path, device="cpu")
            checkpoint = safetensors_load_file(model_path, device="cpu")
        else:
            checkpoint = torch.load(model_path, map_location="cpu")
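The `_needs_conversion`/`_convert_model` pair above hinges on two small pieces of logic: comparing the converted copy's mtime against both the recorded conversion timestamp and the source checkpoint, and branching between safetensors and pickle loading. A rough standalone sketch of both, with hypothetical helper names:

```python
from pathlib import Path
from typing import Optional

import torch
from safetensors.torch import load_file as safetensors_load_file


def converted_copy_is_fresh(src_checkpoint: Path, converted_dir: Path, converted_at: Optional[float]) -> bool:
    """Return True when an on-disk diffusers conversion is newer than both the
    source checkpoint and the recorded conversion timestamp."""
    marker = converted_dir / "config.json"
    if not marker.exists():
        return False
    mtime = marker.stat().st_mtime
    return mtime >= (converted_at or 0.0) and mtime >= src_checkpoint.stat().st_mtime


def load_checkpoint_state_dict(model_path: Path) -> dict:
    """Load a raw checkpoint onto the CPU, branching on the file format."""
    if model_path.suffix == ".safetensors":
        return safetensors_load_file(model_path, device="cpu")
    return torch.load(model_path, map_location="cpu")
```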
@@ -3,9 +3,10 @@
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
@@ -41,6 +42,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
# TO DO: Add exception handling
|
||||
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
|
||||
result = None
|
||||
if submodel_type:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||
@@ -64,6 +66,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
assert result is not None
|
||||
return result
|
||||
|
||||
# TO DO: Add exception handling
|
||||
@@ -75,7 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
result: ModelMixin = getattr(res_type, class_name)
|
||||
return result
|
||||
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
|
||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||
|
||||
|
||||
@@ -83,8 +86,8 @@ class ConfigLoader(ConfigMixin):
|
||||
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # pyright: ignore [reportIncompatibleMethodOverride]
|
||||
"""Load a diffusrs ConfigMixin configuration."""
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
# Diffusers doesn't provide typing info
|
||||
# TODO(psyche): the types on this diffusers method are not correct
|
||||
return super().load_config(*args, **kwargs) # type: ignore
|
||||
|
||||
@@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
|
||||
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
|
||||
@@ -22,8 +22,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
||||
class LoraLoader(ModelLoader):
|
||||
"""Class to load LoRA models."""
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
||||
class OnnyxDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load onnx models."""
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
@@ -16,7 +17,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@@ -54,11 +55,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
return result
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
@@ -73,7 +74,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
|
||||
)
|
||||
|
||||
config_file = config.config
|
||||
config_file = config.config_path
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
convert_ckpt_to_diffusers(
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
@@ -13,24 +13,25 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
class VaeLoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
@@ -38,16 +39,15 @@ class VaeLoader(GenericDiffusersLoader):
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
# TO DO: check whether sdxl VAE models convert.
|
||||
# TODO(MM2): check whether sdxl VAE models convert.
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||
raise Exception(f"VAE conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
config_file = (
|
||||
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
|
||||
)
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(model_path, map_location="cpu")
|
||||
|
||||
@@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
|
||||
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
|
||||
assert isinstance(ckpt_config, DictConfig)
|
||||
|
||||
vae_model = convert_ldm_vae_to_diffusers(
|
||||
|
||||
@@ -16,6 +16,7 @@ from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.utils import logging as dlogging
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
from . import (
|
||||
@@ -117,7 +118,6 @@ class ModelMerger(object):
|
||||
config = self._installer.app_config
|
||||
store = self._installer.record_store
|
||||
base_models: Set[BaseModelType] = set()
|
||||
vae = None
|
||||
variant = None if self._installer.app_config.full_precision else "fp16"
|
||||
|
||||
assert (
|
||||
@@ -134,10 +134,6 @@ class ModelMerger(object):
|
||||
"normal"
|
||||
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
||||
|
||||
# pick up the first model's vae
|
||||
if key == model_keys[0]:
|
||||
vae = info.vae
|
||||
|
||||
# tally base models used
|
||||
base_models.add(info.base)
|
||||
model_paths.extend([config.models_path / info.path])
|
||||
@@ -163,12 +159,10 @@ class ModelMerger(object):
|
||||
|
||||
# update model's config
|
||||
model_config = self._installer.record_store.get_model(key)
|
||||
model_config.update(
|
||||
{
|
||||
"name": merged_model_name,
|
||||
"description": f"Merge of models {', '.join(model_names)}",
|
||||
"vae": vae,
|
||||
}
|
||||
model_config.name = merged_model_name
|
||||
model_config.description = f"Merge of models {', '.join(model_names)}"
|
||||
|
||||
self._installer.record_store.update_model(
|
||||
key, ModelRecordChanges(name=model_config.name, description=model_config.description)
|
||||
)
|
||||
self._installer.record_store.update_model(key, model_config)
|
||||
return model_config
|
||||
|
||||
@@ -25,9 +25,7 @@ from .metadata_base import (
|
||||
AnyModelRepoMetadataValidator,
|
||||
BaseMetadata,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
LicenseRestrictions,
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
@@ -38,10 +36,8 @@ __all__ = [
|
||||
"AnyModelRepoMetadataValidator",
|
||||
"CivitaiMetadata",
|
||||
"CivitaiMetadataFetch",
|
||||
"CommercialUsage",
|
||||
"HuggingFaceMetadata",
|
||||
"HuggingFaceMetadataFetch",
|
||||
"LicenseRestrictions",
|
||||
"ModelMetadataFetchBase",
|
||||
"BaseMetadata",
|
||||
"ModelMetadataWithFiles",
|
||||
|
||||
@@ -23,22 +23,21 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||
print(metadata.trained_words)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
LicenseRestrictions,
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
@@ -52,10 +51,13 @@ CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
|
||||
|
||||
StringSetAdapter = TypeAdapter(set[str])
|
||||
|
||||
|
||||
class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""Fetch model metadata from Civitai."""
|
||||
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
@@ -63,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
this module without an internet connection.
|
||||
"""
|
||||
self._requests = session or requests.Session()
|
||||
self._api_key = api_key
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
@@ -102,22 +105,21 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
May raise an `UnknownMetadataException`.
|
||||
"""
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json)
|
||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||
return self._from_api_response(model_json)
|
||||
|
||||
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
try:
|
||||
version_id = version_id or model_json["modelVersions"][0]["id"]
|
||||
version_id = version_id or api_response["modelVersions"][0]["id"]
|
||||
except TypeError as excp:
|
||||
raise UnknownMetadataException from excp
|
||||
|
||||
# loop till we find the section containing the version requested
|
||||
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
|
||||
version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
|
||||
if not version_sections:
|
||||
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
|
||||
|
||||
version_json = version_sections[0]
|
||||
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
|
||||
|
||||
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
|
||||
primary = [x for x in version_json["files"] if x.get("primary")]
|
||||
@@ -134,36 +136,23 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
url = url + f"?type={primary_file['type']}{metadata_string}"
|
||||
model_files = [
|
||||
RemoteModelFile(
|
||||
url=url,
|
||||
url=self._get_url_with_api_key(url),
|
||||
path=Path(primary_file["name"]),
|
||||
size=int(primary_file["sizeKB"] * 1024),
|
||||
sha256=primary_file["hashes"]["SHA256"],
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
|
||||
except ValidationError:
|
||||
trigger_phrases: set[str] = set()
|
||||
|
||||
return CivitaiMetadata(
|
||||
id=model_json["id"],
|
||||
name=version_json["name"],
|
||||
version_id=version_json["id"],
|
||||
version_name=version_json["name"],
|
||||
created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
|
||||
updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
|
||||
published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
|
||||
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
|
||||
files=model_files,
|
||||
download_url=version_json["downloadUrl"],
|
||||
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
|
||||
author=model_json["creator"]["username"],
|
||||
description=model_json["description"],
|
||||
version_description=version_json["description"] or "",
|
||||
tags=model_json["tags"],
|
||||
trained_words=version_json["trainedWords"],
|
||||
nsfw=model_json["nsfw"],
|
||||
restrictions=LicenseRestrictions(
|
||||
AllowNoCredit=model_json["allowNoCredit"],
|
||||
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
|
||||
AllowDerivatives=model_json["allowDerivatives"],
|
||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||
),
|
||||
trigger_phrases=trigger_phrases,
|
||||
api_response=json.dumps(version_json),
|
||||
)
|
||||
|
||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
@@ -174,14 +163,14 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""
|
||||
if model_id is None:
|
||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||
version = self._requests.get(version_url).json()
|
||||
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
|
||||
if error := version.get("error"):
|
||||
raise UnknownMetadataException(error)
|
||||
model_id = version["modelId"]
|
||||
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json, version_id)
|
||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||
return self._from_api_response(model_json, version_id)
|
||||
|
||||
    @classmethod
    def from_json(cls, json: str) -> CivitaiMetadata:
@@ -189,6 +178,11 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
        metadata = CivitaiMetadata.model_validate_json(json)
        return metadata

    def _get_url_with_api_key(self, url: str) -> str:
        if not self._api_key:
            return url


def _fix_timezone(date: str) -> str:
    return re.sub(r"Z$", "+00:00", date)
        if "?" in url:
            return f"{url}&token={self._api_key}"

        return f"{url}?token={self._api_key}"
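The new `_get_url_with_api_key` helper simply appends the Civitai token as a query parameter, choosing `&` or `?` depending on whether the URL already carries a query string. A minimal standalone equivalent (the helper name is illustrative):

```python
from typing import Optional


def url_with_token(url: str, api_key: Optional[str]) -> str:
    """Append an API token as a query parameter, preserving any existing query string."""
    if not api_key:
        return url
    separator = "&" if "?" in url else "?"
    return f"{url}{separator}token={api_key}"


assert url_with_token("https://civitai.com/api/download/models/1", "abc") == (
    "https://civitai.com/api/download/models/1?token=abc"
)
```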
@@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
|
||||
print(metadata.tags)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@@ -23,7 +24,7 @@ from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFo
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
@@ -60,6 +61,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
# Little loop which tries fetching a revision corresponding to the selected variant.
|
||||
# If not available, then set variant to None and get the default.
|
||||
# If this too fails, raise exception.
|
||||
|
||||
model_info = None
|
||||
while not model_info:
|
||||
try:
|
||||
@@ -72,23 +74,24 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
            else:
                variant = None

        files: list[RemoteModelFile] = []

        _, name = id.split("/")
        return HuggingFaceMetadata(
            id=model_info.id,
            author=model_info.author,
            name=name,
            last_modified=model_info.last_modified,
            tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
            tags=model_info.tags,
            files=[
        for s in model_info.siblings or []:
            assert s.rfilename is not None
            assert s.size is not None
            files.append(
                RemoteModelFile(
                    url=hf_hub_url(id, x.rfilename, revision=variant),
                    path=Path(name, x.rfilename),
                    size=x.size,
                    sha256=x.lfs.get("sha256") if x.lfs else None,
                    url=hf_hub_url(id, s.rfilename, revision=variant),
                    path=Path(name, s.rfilename),
                    size=s.size,
                    sha256=s.lfs.get("sha256") if s.lfs else None,
                )
                for x in model_info.siblings
            ],
        )

        return HuggingFaceMetadata(
            id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
        )
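The rewritten block walks `model_info.siblings` in a plain loop instead of a list comprehension, so files with missing names or sizes can be asserted on before building a record. A rough sketch of the same idea against the Hugging Face Hub client (the repo id is only an example, and the sibling metadata actually returned depends on the huggingface_hub version):

```python
from pathlib import Path

from huggingface_hub import HfApi, hf_hub_url

# files_metadata=True asks the Hub to include per-file sizes in the siblings list.
repo_id = "stabilityai/sdxl-turbo"  # example repo, not taken from the diff
info = HfApi().model_info(repo_id, files_metadata=True)

files = []
for sibling in info.siblings or []:
    if sibling.rfilename is None or sibling.size is None:
        continue
    files.append(
        {
            "url": hf_hub_url(repo_id, sibling.rfilename),
            "path": Path(repo_id.split("/")[1], sibling.rfilename),
            "size": sibling.size,
        }
    )
```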
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
|
||||
@@ -14,10 +14,8 @@ versions of these fields are intended to be kept in sync with the
|
||||
remote repo.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from huggingface_hub import configure_http_backend, hf_hub_url
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
@@ -25,7 +23,6 @@ from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
|
||||
from ..util import select_hf_files
|
||||
@@ -35,31 +32,6 @@ class UnknownMetadataException(Exception):
|
||||
"""Raised when no metadata is available for a model."""
|
||||
|
||||
|
||||
class CommercialUsage(str, Enum):
|
||||
"""Type of commercial usage allowed."""
|
||||
|
||||
No = "None"
|
||||
Image = "Image"
|
||||
Rent = "Rent"
|
||||
RentCivit = "RentCivit"
|
||||
Sell = "Sell"
|
||||
|
||||
|
||||
class LicenseRestrictions(BaseModel):
|
||||
"""Broad categories of licensing restrictions."""
|
||||
|
||||
AllowNoCredit: bool = Field(
|
||||
description="if true, model can be redistributed without crediting author", default=False
|
||||
)
|
||||
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
|
||||
AllowDifferentLicense: bool = Field(
|
||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||
)
|
||||
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
|
||||
description="Type of commercial use allowed if no commercial use is allowed.", default=None
|
||||
)
|
||||
|
||||
|
||||
class RemoteModelFile(BaseModel):
|
||||
"""Information about a downloadable file that forms part of a model."""
|
||||
|
||||
@@ -69,24 +41,10 @@ class RemoteModelFile(BaseModel):
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
|
||||
class ModelDefaultSettings(BaseModel):
|
||||
vae: str | None
|
||||
vae_precision: str | None
|
||||
scheduler: SCHEDULER_NAME_VALUES | None
|
||||
steps: int | None
|
||||
cfg_scale: float | None
|
||||
cfg_rescale_multiplier: float | None
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="default settings for this model", default=None
|
||||
)
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
@@ -124,60 +82,16 @@ class CivitaiMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by Civitai."""
|
||||
|
||||
type: Literal["civitai"] = "civitai"
|
||||
id: int = Field(description="Civitai version identifier")
|
||||
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
|
||||
version_id: int = Field(description="Civitai model version identifier")
|
||||
created: datetime = Field(description="date the model was created")
|
||||
updated: datetime = Field(description="date the model was last modified")
|
||||
published: datetime = Field(description="date the model was published to Civitai")
|
||||
description: str = Field(description="text description of model; may contain HTML")
|
||||
version_description: str = Field(
|
||||
description="text description of the model's reversion; usually change history; may contain HTML"
|
||||
)
|
||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
|
||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
|
||||
weight_minmax: Tuple[float, float] = Field(
|
||||
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
|
||||
) # note: For future use
|
||||
|
||||
@property
|
||||
def credit_required(self) -> bool:
|
||||
"""Return True if you must give credit for derivatives of this model and images generated from it."""
|
||||
return not self.restrictions.AllowNoCredit
|
||||
|
||||
@property
|
||||
def allow_commercial_use(self) -> bool:
|
||||
"""Return True if commercial use is allowed."""
|
||||
if self.restrictions.AllowCommercialUse is None:
|
||||
return False
|
||||
else:
|
||||
# accommodate schema change
|
||||
acu = self.restrictions.AllowCommercialUse
|
||||
commercial_usage = acu if isinstance(acu, set) else {acu}
|
||||
return CommercialUsage.No not in commercial_usage
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
"""Return True if derivatives of this model can be redistributed."""
|
||||
return self.restrictions.AllowDerivatives
|
||||
|
||||
@property
|
||||
def allow_different_license(self) -> bool:
|
||||
"""Return true if derivatives of this model can use a different license."""
|
||||
return self.restrictions.AllowDifferentLicense
|
||||
trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
|
||||
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
|
||||
|
||||
|
||||
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by HuggingFace."""
|
||||
|
||||
type: Literal["huggingface"] = "huggingface"
|
||||
id: str = Field(description="huggingface model id")
|
||||
tag_dict: Dict[str, Any]
|
||||
last_modified: datetime = Field(description="date of last commit to repo")
|
||||
id: str = Field(description="The HF model id")
|
||||
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
|
||||
|
||||
def download_urls(
|
||||
self,
|
||||
@@ -206,7 +120,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||
# the next step reads model_index.json to determine which subdirectories belong
|
||||
# to the model
|
||||
if Path(f"{prefix}model_index.json") in paths:
|
||||
url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
|
||||
url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
|
||||
resp = session.get(url)
|
||||
resp.raise_for_status()
|
||||
submodels = resp.json()
|
||||
|
||||
@@ -1,221 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .fetch import ModelMetadataFetchBase
|
||||
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
|
||||
class ModelMetadataStore:
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
@@ -8,6 +8,7 @@ import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.util import SilenceWarnings
|
||||
|
||||
from .config import (
|
||||
@@ -17,6 +18,7 @@ from .config import (
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
@@ -95,8 +97,8 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"AutoencoderKL": ModelType.VAE,
|
||||
"AutoencoderTiny": ModelType.VAE,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
@@ -108,14 +110,6 @@ class ModelProbe(object):
|
||||
) -> None:
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
fields: Optional[Dict[str, Any]] = None,
|
||||
) -> AnyModelConfig:
|
||||
return cls.probe(model_path, fields)
|
||||
|
||||
@classmethod
|
||||
def probe(
|
||||
cls,
|
||||
@@ -137,19 +131,21 @@ class ModelProbe(object):
|
||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||
model_info = None
|
||||
model_type = None
|
||||
if format_type == "diffusers":
|
||||
if format_type is ModelFormat.Diffusers:
|
||||
model_type = cls.get_model_type_from_folder(model_path)
|
||||
else:
|
||||
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
|
||||
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
|
||||
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
if not probe_class:
|
||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||
|
||||
hash = ModelHash().hash(model_path)
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
|
||||
fields["source"] = fields.get("source") or model_path.as_posix()
|
||||
fields["key"] = fields.get("key", uuid_string())
|
||||
fields["path"] = model_path.as_posix()
|
||||
fields["type"] = fields.get("type") or model_type
|
||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||
@@ -161,15 +157,17 @@ class ModelProbe(object):
|
||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["original_hash"] = fields.get("original_hash") or hash
|
||||
fields["current_hash"] = fields.get("current_hash") or hash
|
||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
||||
|
||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
||||
fields["config"] = cls._get_checkpoint_config_path(
|
||||
if (
|
||||
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
|
||||
and fields["format"] is ModelFormat.Checkpoint
|
||||
):
|
||||
fields["config_path"] = cls._get_checkpoint_config_path(
|
||||
model_path,
|
||||
model_type=fields["type"],
|
||||
base_type=fields["base"],
|
||||
@@ -179,7 +177,7 @@ class ModelProbe(object):
|
||||
|
||||
# additional fields needed for main non-checkpoint models
|
||||
elif fields["type"] == ModelType.Main and fields["format"] in [
|
||||
ModelFormat.Onnx,
|
||||
ModelFormat.ONNX,
|
||||
ModelFormat.Olive,
|
||||
ModelFormat.Diffusers,
|
||||
]:
|
||||
@@ -213,11 +211,11 @@ class ModelProbe(object):
        if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
            return ModelType.Main
        elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
            return ModelType.Vae
            return ModelType.VAE
        elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
            return ModelType.Lora
            return ModelType.LoRA
        elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
            return ModelType.Lora
            return ModelType.LoRA
        elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
            return ModelType.ControlNet
        elif key in {"emb_params", "string_to_param"}:
            return ModelType.TextualInversion
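The checkpoint probe above classifies a model by looking for well-known key prefixes in its state dict. A simplified standalone sketch of that heuristic, returning plain strings instead of `ModelType` members; the prefix sets are copied from the logic above but are not exhaustive:

```python
from typing import Optional


def sniff_model_type(state_dict: dict) -> Optional[str]:
    """Guess a checkpoint's model type from characteristic state-dict keys."""
    for key in state_dict.keys():
        if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
            return "main"
        if key.startswith(("encoder.conv_in", "decoder.conv_in")):
            return "vae"
        if key.startswith(("lora_te_", "lora_unet_")) or key.endswith(
            ("to_k_lora.up.weight", "to_q_lora.down.weight")
        ):
            return "lora"
        if key.startswith(("control_model", "input_blocks")):
            return "controlnet"
        if key in {"emb_params", "string_to_param"}:
            return "embedding"
    return None
```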
@@ -239,7 +237,7 @@ class ModelProbe(object):
|
||||
if (folder_path / f"learned_embeds.{suffix}").exists():
|
||||
return ModelType.TextualInversion
|
||||
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
||||
return ModelType.Lora
|
||||
return ModelType.LoRA
|
||||
if (folder_path / "unet/model.onnx").exists():
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
@@ -285,13 +283,21 @@ class ModelProbe(object):
        if possible_conf.exists():
            return possible_conf.absolute()

        if model_type == ModelType.Main:
        if model_type is ModelType.Main:
            config_file = LEGACY_CONFIGS[base_type][variant_type]
            if isinstance(config_file, dict): # need another tier for sd-2.x models
                config_file = config_file[prediction_type]
        elif model_type == ModelType.ControlNet:
        elif model_type is ModelType.ControlNet:
            config_file = (
                "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
                "../controlnet/cldm_v15.yaml"
                if base_type is BaseModelType.StableDiffusion1
                else "../controlnet/cldm_v21.yaml"
            )
        elif model_type is ModelType.VAE:
            config_file = (
                "../stable-diffusion/v1-inference.yaml"
                if base_type is BaseModelType.StableDiffusion1
                else "../stable-diffusion/v2-inference.yaml"
            )
        else:
            raise InvalidModelConfigException(
@@ -497,12 +503,12 @@ class FolderProbeBase(ProbeBase):
            if ".fp16" in x.suffixes:
                return ModelRepoVariant.FP16
            if "openvino_model" in x.name:
                return ModelRepoVariant.OPENVINO
                return ModelRepoVariant.OpenVINO
            if "flax_model" in x.name:
                return ModelRepoVariant.FLAX
                return ModelRepoVariant.Flax
            if x.suffix == ".onnx":
                return ModelRepoVariant.ONNX
        return ModelRepoVariant.DEFAULT
        return ModelRepoVariant.Default


class PipelineFolderProbe(FolderProbeBase):
@@ -708,8 +714,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):

############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
@@ -717,8 +723,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)

ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)

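The registrations above key a probe class on a (format, model type) pair. Below is a minimal, hypothetical sketch of the registry pattern this implies; only `register_probe` itself appears in the diff, the storage and lookup details are assumptions.

```python
# Hypothetical registry sketch; not the actual ModelProbe internals.
from typing import Callable, Dict, Tuple

PROBES: Dict[Tuple[str, str], Callable] = {}

def register_probe(format_: str, model_type: str, probe_class: Callable) -> None:
    PROBES[(format_, model_type)] = probe_class

def get_probe_class(format_: str, model_type: str) -> Callable:
    try:
        return PROBES[(format_, model_type)]
    except KeyError as exc:
        raise ValueError(f"no probe registered for {format_}/{model_type}") from exc
```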
@@ -13,6 +1,7 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx')
"""

import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set

@@ -34,7 +35,7 @@ def filter_files(
    The file list can be obtained from the `files` field of HuggingFaceMetadata,
    as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
    """
    variant = variant or ModelRepoVariant.DEFAULT
    variant = variant or ModelRepoVariant.Default
    paths: List[Path] = []
    root = files[0].parts[0]

@@ -73,64 +74,81 @@ def filter_files(
|
||||
return sorted(_filter_by_variant(paths, variant))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubfolderCandidate:
|
||||
path: Path
|
||||
score: int
|
||||
|
||||
|
||||
def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
result = set()
|
||||
basenames: Dict[Path, Path] = {}
|
||||
result: set[Path] = set()
|
||||
subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
|
||||
for path in files:
|
||||
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
|
||||
if variant == ModelRepoVariant.ONNX:
|
||||
result.add(path)
|
||||
|
||||
elif "openvino_model" in path.name:
|
||||
if variant == ModelRepoVariant.OPENVINO:
|
||||
if variant == ModelRepoVariant.OpenVINO:
|
||||
result.add(path)
|
||||
|
||||
elif "flax_model" in path.name:
|
||||
if variant == ModelRepoVariant.FLAX:
|
||||
if variant == ModelRepoVariant.Flax:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
|
||||
elif variant in [
|
||||
ModelRepoVariant.FP16,
|
||||
ModelRepoVariant.FP32,
|
||||
ModelRepoVariant.DEFAULT,
|
||||
]:
|
||||
parent = path.parent
|
||||
suffixes = path.suffixes
|
||||
if len(suffixes) == 2:
|
||||
variant_label, suffix = suffixes
|
||||
basename = parent / Path(path.stem).stem
|
||||
else:
|
||||
variant_label = ""
|
||||
suffix = suffixes[0]
|
||||
basename = parent / path.stem
|
||||
ModelRepoVariant.Default,
|
||||
] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
|
||||
# For weights files, we want to select the best one for each subfolder. For example, we may have multiple
|
||||
# text encoders:
|
||||
#
|
||||
# - text_encoder/model.fp16.safetensors
|
||||
# - text_encoder/model.safetensors
|
||||
# - text_encoder/pytorch_model.bin
|
||||
# - text_encoder/pytorch_model.fp16.bin
|
||||
#
|
||||
# We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
|
||||
# variant and format and select the best one.
|
||||
|
||||
if previous := basenames.get(basename):
|
||||
if (
|
||||
previous.suffix != ".safetensors" and suffix == ".safetensors"
|
||||
): # replace non-safetensors with safetensors when available
|
||||
basenames[basename] = path
|
||||
if variant_label == f".{variant}":
|
||||
basenames[basename] = path
|
||||
elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
|
||||
basenames[basename] = path
|
||||
else:
|
||||
basenames[basename] = path
|
||||
parent = path.parent
|
||||
score = 0
|
||||
|
||||
if path.suffix == ".safetensors":
|
||||
score += 1
|
||||
|
||||
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
|
||||
|
||||
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
|
||||
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
|
||||
if candidate_variant_label == f".{variant}" or (
|
||||
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
|
||||
):
|
||||
score += 1
|
||||
|
||||
if parent not in subfolder_weights:
|
||||
subfolder_weights[parent] = []
|
||||
|
||||
subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score))
|
||||
|
||||
else:
|
||||
continue
|
||||
|
||||
for v in basenames.values():
|
||||
result.add(v)
|
||||
for candidate_list in subfolder_weights.values():
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
result.add(highest_score_candidate.path)
|
||||
|
||||
# If one of the architecture-related variants was specified and no files matched other than
|
||||
# config and text files then we return an empty list
|
||||
if (
|
||||
variant
|
||||
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
|
||||
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OpenVINO, ModelRepoVariant.Flax]
|
||||
and not any(variant.value in x.name for x in result)
|
||||
):
|
||||
return set()
|
||||
|
||||
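To make the scoring rules above concrete, here is a standalone sketch of picking the best weights file in a single subfolder: prefer `.safetensors`, prefer an exact variant suffix, and accept an unsuffixed file when `fp32`/default is requested. It mirrors the `SubfolderCandidate` logic but is simplified and is not the shipped implementation.

```python
from pathlib import Path
from typing import List

def best_weights_file(files: List[Path], variant: str = "fp16") -> Path:
    """Pick the best weights file in one subfolder (simplified sketch)."""
    def score(path: Path) -> int:
        s = 0
        if path.suffix == ".safetensors":
            s += 1  # prefer safetensors over .bin/.pt/.ckpt
        label = path.suffixes[0] if len(path.suffixes) == 2 else None
        if label == f".{variant}" or (label is None and variant in ("fp32", "default")):
            s += 1  # exact variant match, or unsuffixed file for fp32/default requests
        return s
    return max(files, key=score)

files = [
    Path("text_encoder/model.fp16.safetensors"),
    Path("text_encoder/model.safetensors"),
    Path("text_encoder/pytorch_model.bin"),
    Path("text_encoder/pytorch_model.fp16.bin"),
]
assert best_weights_file(files, "fp16").name == "model.fp16.safetensors"
```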
@@ -23,12 +23,9 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
IPAdapterConditioningInfo,
|
||||
TextConditioningData,
|
||||
)
|
||||
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
||||
|
||||
from ..util import auto_detect_slice_size, normalize_device
|
||||
|
||||
@@ -173,11 +170,10 @@ class ControlNetData:

@dataclass
class IPAdapterData:
    ip_adapter_model: IPAdapter
    ip_adapter_conditioning: IPAdapterConditioningInfo

    # Either a single weight applied to all steps, or a list of weights for each step.
    ip_adapter_model: IPAdapter = Field(default=None)
    # TODO: change to polymorphic so can do different weights per step (once implemented...)
    weight: Union[float, List[float]] = Field(default=1.0)
    # weight: float = Field(default=1.0)
    begin_step_percent: float = Field(default=0.0)
    end_step_percent: float = Field(default=1.0)

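A small sketch of how `begin_step_percent`/`end_step_percent` and a per-step `weight` list typically resolve to an effective scale at each denoising step, the same idea as the `first_adapter_step`/`last_adapter_step` handling later in this diff. The helper name and the floor/ceil rounding are assumptions.

```python
import math
from typing import List, Union

def ip_adapter_scale_for_step(
    step_index: int,
    total_steps: int,
    weight: Union[float, List[float]] = 1.0,
    begin_step_percent: float = 0.0,
    end_step_percent: float = 1.0,
) -> float:
    """Return the IP-Adapter scale to apply at this step (0.0 = disabled)."""
    first_step = math.floor(begin_step_percent * total_steps)
    last_step = math.ceil(end_step_percent * total_steps) - 1
    if not (first_step <= step_index <= last_step):
        return 0.0  # outside the adapter's active step range
    if isinstance(weight, list):
        return weight[step_index]  # one weight per step
    return weight

assert ip_adapter_scale_for_step(0, 10, 0.8, begin_step_percent=0.5) == 0.0
assert ip_adapter_scale_for_step(7, 10, 0.8, begin_step_percent=0.5) == 0.8
```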
@@ -318,8 +314,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
num_inference_steps: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
conditioning_data: TextConditioningData,
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
noise: Optional[torch.Tensor],
|
||||
timesteps: torch.Tensor,
|
||||
@@ -379,7 +374,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents,
|
||||
timesteps,
|
||||
conditioning_data,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
@@ -399,8 +393,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
timesteps,
|
||||
conditioning_data: TextConditioningData,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
@@ -417,35 +410,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if timesteps.shape[0] == 0:
|
||||
return latents
|
||||
|
||||
extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
|
||||
use_cross_attention_control = (
|
||||
extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
|
||||
)
|
||||
use_ip_adapter = ip_adapter_data is not None
|
||||
use_regional_prompting = (
|
||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||
)
|
||||
if use_cross_attention_control and use_ip_adapter:
|
||||
raise ValueError(
|
||||
"Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously."
|
||||
)
|
||||
if use_cross_attention_control and use_regional_prompting:
|
||||
raise ValueError(
|
||||
"Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
|
||||
)
|
||||
|
||||
unet_attention_patcher = None
|
||||
self.use_ip_adapter = use_ip_adapter
|
||||
attn_ctx = nullcontext()
|
||||
if use_cross_attention_control:
|
||||
ip_adapter_unet_patcher = None
|
||||
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
)
|
||||
if use_ip_adapter or use_regional_prompting:
|
||||
ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
|
||||
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
|
||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
self.use_ip_adapter = False
|
||||
elif ip_adapter_data is not None:
|
||||
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
||||
# As it is now, the IP-Adapter will silently be skipped.
|
||||
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
|
||||
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
self.use_ip_adapter = True
|
||||
else:
|
||||
attn_ctx = nullcontext()
|
||||
|
||||
with attn_ctx:
|
||||
if callback is not None:
|
||||
@@ -468,14 +448,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
unet_attention_patcher=unet_attention_patcher,
|
||||
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
|
||||
)
|
||||
latents = step_output.prev_sample
|
||||
|
||||
latents = self.invokeai_diffuser.do_latent_postprocessing(
|
||||
postprocessing_settings=conditioning_data.postprocessing_settings,
|
||||
latents=latents,
|
||||
sigma=batched_t,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
)
|
||||
|
||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||
|
||||
if callback is not None:
|
||||
@@ -497,15 +485,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self,
|
||||
t: torch.Tensor,
|
||||
latents: torch.Tensor,
|
||||
conditioning_data: TextConditioningData,
|
||||
conditioning_data: ConditioningData,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
|
||||
ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
|
||||
):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
@@ -528,10 +515,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
if step_index >= first_adapter_step and step_index <= last_adapter_step:
|
||||
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
|
||||
unet_attention_patcher.set_scale(i, weight)
|
||||
ip_adapter_unet_patcher.set_scale(i, weight)
|
||||
else:
|
||||
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
|
||||
unet_attention_patcher.set_scale(i, 0.0)
|
||||
ip_adapter_unet_patcher.set_scale(i, 0.0)
|
||||
|
||||
# Handle ControlNet(s)
|
||||
down_block_additional_residuals = None
|
||||
@@ -575,17 +562,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
down_intrablock_additional_residuals = accum_adapter_state
|
||||
|
||||
ip_adapter_conditioning = None
|
||||
if ip_adapter_data is not None:
|
||||
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
|
||||
|
||||
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
||||
sample=latent_model_input,
|
||||
timestep=t, # TODO: debug how handled batched and non batched timesteps
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
conditioning_data=conditioning_data,
|
||||
ip_adapter_conditioning=ip_adapter_conditioning,
|
||||
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
|
||||
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
|
||||
@@ -605,7 +587,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
|
||||
|
||||
# TODO: issue to diffusers?
|
||||
# undo internal counter increment done by scheduler.step, so timestep can be resolved as before call
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
import dataclasses
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -8,11 +10,6 @@ from .cross_attention_control import Arguments
|
||||
|
||||
@dataclass
|
||||
class ExtraConditioningInfo:
|
||||
"""Extra conditioning information produced by Compel.
|
||||
|
||||
This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel).
|
||||
"""
|
||||
|
||||
tokens_count_including_eos_bos: int
|
||||
cross_attention_control_args: Optional[Arguments] = None
|
||||
|
||||
@@ -23,8 +20,6 @@ class ExtraConditioningInfo:
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
"""SD 1/2 text conditioning information produced by Compel."""
|
||||
|
||||
embeds: torch.Tensor
|
||||
extra_conditioning: Optional[ExtraConditioningInfo]
|
||||
|
||||
@@ -40,8 +35,6 @@ class ConditioningFieldData:
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
"""SDXL text conditioning information produced by Compel."""
|
||||
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
@@ -51,6 +44,14 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
warmup: float
|
||||
h_symmetry_time_pct: Optional[float]
|
||||
v_symmetry_time_pct: Optional[float]
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterConditioningInfo:
|
||||
cond_image_prompt_embeds: torch.Tensor
|
||||
@@ -64,55 +65,41 @@ class IPAdapterConditioningInfo:
|
||||
|
||||
|
||||
@dataclass
|
||||
class Range:
|
||||
start: int
|
||||
end: int
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: BasicConditioningInfo
|
||||
text_embeddings: BasicConditioningInfo
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
"""
|
||||
guidance_scale: Union[float, List[float]]
|
||||
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
|
||||
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
|
||||
"""
|
||||
guidance_rescale_multiplier: float = 0
|
||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||
"""
|
||||
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
|
||||
"""
|
||||
postprocessing_settings: Optional[PostprocessingSettings] = None
|
||||
|
||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
|
||||
|
||||
class TextConditioningRegions:
|
||||
def __init__(
|
||||
self,
|
||||
masks: torch.Tensor,
|
||||
ranges: list[Range],
|
||||
mask_weights: list[float],
|
||||
):
|
||||
# A binary mask indicating the regions of the image that the prompt should be applied to.
|
||||
# Shape: (1, num_prompts, height, width)
|
||||
# Dtype: torch.bool
|
||||
self.masks = masks
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.text_embeddings.dtype
|
||||
|
||||
# A list of ranges indicating the start and end indices of the embeddings that corresponding mask applies to.
|
||||
# ranges[i] contains the embedding range for the i'th prompt / mask.
|
||||
self.ranges = ranges
|
||||
|
||||
self.mask_weights = mask_weights
|
||||
|
||||
assert self.masks.shape[1] == len(self.ranges) == len(self.mask_weights)
|
||||
|
||||
|
||||
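A hedged example of building `TextConditioningRegions` for two prompts using the `Range` dataclass above; the mask shape, token spans, and weights are illustrative values, not taken from a real pipeline run.

```python
import torch

# Assumes Range and TextConditioningRegions as defined in this diff.
h, w = 64, 64
masks = torch.zeros((1, 2, h, w), dtype=torch.bool)
masks[0, 0, :, : w // 2] = True   # prompt 0 applies to the left half of the image
masks[0, 1, :, w // 2 :] = True   # prompt 1 applies to the right half

regions = TextConditioningRegions(
    masks=masks,
    ranges=[Range(start=0, end=77), Range(start=77, end=154)],  # embedding span per prompt
    mask_weights=[1.0, 1.0],
)
```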
class TextConditioningData:
|
||||
def __init__(
|
||||
self,
|
||||
uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
uncond_regions: Optional[TextConditioningRegions],
|
||||
cond_regions: Optional[TextConditioningRegions],
|
||||
guidance_scale: Union[float, List[float]],
|
||||
guidance_rescale_multiplier: float = 0,
|
||||
):
|
||||
self.uncond_text = uncond_text
|
||||
self.cond_text = cond_text
|
||||
self.uncond_regions = uncond_regions
|
||||
self.cond_regions = cond_regions
|
||||
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
self.guidance_scale = guidance_scale
|
||||
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
|
||||
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
self.guidance_rescale_multiplier = guidance_rescale_multiplier
|
||||
|
||||
def is_sdxl(self):
|
||||
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
return isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
||||
scheduler_args = dict(self.scheduler_args)
|
||||
step_method = inspect.signature(scheduler.step)
|
||||
for name, value in kwargs.items():
|
||||
try:
|
||||
step_method.bind_partial(**{name: value})
|
||||
except TypeError:
|
||||
# FIXME: don't silently discard arguments
|
||||
pass # debug("%s does not accept argument named %r", scheduler, name)
|
||||
else:
|
||||
scheduler_args[name] = value
|
||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
|
||||
|
||||
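`add_scheduler_args_if_applicable` keeps only the keyword arguments that the scheduler's `step()` actually accepts. Here is the same idea as a self-contained sketch, independent of the surrounding classes:

```python
import inspect
from typing import Any, Callable, Dict

def filter_kwargs_for_callable(func: Callable, **kwargs) -> Dict[str, Any]:
    """Keep only the kwargs that `func` can accept (sketch of the pattern above)."""
    sig = inspect.signature(func)
    accepted: Dict[str, Any] = {}
    for name, value in kwargs.items():
        try:
            sig.bind_partial(**{name: value})
        except TypeError:
            continue  # the callable does not accept this argument; drop it
        accepted[name] = value
    return accepted

def step(sample, timestep, eta=0.0):  # stand-in for scheduler.step
    return sample

print(filter_kwargs_for_callable(step, eta=0.5, generator=None))  # {'eta': 0.5}
```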
@@ -1,242 +0,0 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
|
||||
from diffusers.utils import USE_PEFT_BACKEND
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
|
||||
class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
|
||||
|
||||
This implementation is based on
|
||||
https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
|
||||
|
||||
Supported custom features:
|
||||
- IP-Adapter
|
||||
- Regional prompt attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
|
||||
ip_adapter_scales: Optional[list[float]] = None,
|
||||
):
|
||||
"""Initialize a CustomAttnProcessor2_0.
|
||||
|
||||
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
|
||||
layer-specific are passed to __init__().
|
||||
|
||||
Args:
|
||||
ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
|
||||
for the i'th IP-Adapter.
|
||||
ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for
|
||||
the i'th IP-Adapter.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._ip_adapter_weights = ip_adapter_weights
|
||||
self._ip_adapter_scales = ip_adapter_scales
|
||||
|
||||
assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None)
|
||||
if self._ip_adapter_weights is not None:
|
||||
assert len(ip_adapter_weights) == len(ip_adapter_scales)
|
||||
|
||||
def _is_ip_adapter_enabled(self) -> bool:
|
||||
return self._ip_adapter_weights is not None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
# For regional prompting:
|
||||
regional_prompt_data: Optional[RegionalPromptData] = None,
|
||||
percent_through: Optional[float] = None,
|
||||
# For IP-Adapter:
|
||||
ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""Apply attention.
|
||||
|
||||
Args:
|
||||
regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
|
||||
apply regional prompt masking.
|
||||
ip_adapter_image_prompt_embeds: The IP-Adapter image prompt embeddings for the current batch.
|
||||
ip_adapter_image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each
|
||||
tensor has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
||||
"""
|
||||
# If true, we are doing cross-attention, if false we are doing self-attention.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End unmodified block from AttnProcessor2_0.
|
||||
|
||||
# Handle regional prompt attention masks.
|
||||
if regional_prompt_data is not None:
|
||||
assert percent_through is not None
|
||||
_, query_seq_len, _ = hidden_states.shape
|
||||
if is_cross_attention:
|
||||
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
|
||||
query_seq_len=query_seq_len, key_seq_len=sequence_length
|
||||
)
|
||||
# TODO(ryand): Avoid redundant type/device conversion here.
|
||||
prompt_region_attention_mask = prompt_region_attention_mask.to(
|
||||
dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
|
||||
attn_mask_weight = 1.0 * ((1 - percent_through) ** 5)
|
||||
else: # self-attention
|
||||
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
|
||||
query_seq_len=query_seq_len,
|
||||
percent_through=percent_through,
|
||||
)
|
||||
attn_mask_weight = 0.3 * ((1 - percent_through) ** 5)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if regional_prompt_data is not None and percent_through < 0.3:
|
||||
# Don't apply to uncond????
|
||||
|
||||
prompt_region_attention_mask = attn.prepare_attention_mask(
|
||||
prompt_region_attention_mask, sequence_length, batch_size
|
||||
)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
prompt_region_attention_mask = prompt_region_attention_mask.view(
|
||||
batch_size, attn.heads, -1, prompt_region_attention_mask.shape[-1]
|
||||
)
|
||||
|
||||
scale_factor = 1 / math.sqrt(query.size(-1))
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight
|
||||
m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0]
|
||||
|
||||
prompt_region_attention_mask = attn_mask_weight * (
|
||||
m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
|
||||
)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = prompt_region_attention_mask
|
||||
else:
|
||||
attention_mask = prompt_region_attention_mask + attention_mask
|
||||
else:
|
||||
pass
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End unmodified block from AttnProcessor2_0.
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
if is_cross_attention and self._is_ip_adapter_enabled():
|
||||
if self._is_ip_adapter_enabled():
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
for ipa_embed, ipa_weights, scale in zip(
|
||||
ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True
|
||||
):
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The token_len dimensions should match.
|
||||
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
||||
|
||||
ip_hidden_states = ipa_embed
|
||||
|
||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
else:
|
||||
# If IP-Adapter is not enabled, then ip_adapter_image_prompt_embeds should not be passed in.
|
||||
assert ip_adapter_image_prompt_embeds is None
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
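The regional masking above converts a binary region mask into an additive attention bias whose strength decays over the denoising schedule (`weight * (1 - percent_through) ** 5`). Below is a compact sketch of that bias construction using the same `m_pos`/`m_neg` idea; the function name and defaults are assumptions.

```python
import torch

def regional_attn_bias(
    attn_weight: torch.Tensor,   # raw QK^T * scale, shape (..., q_len, k_len)
    region_mask: torch.Tensor,   # 1.0 inside the prompt's region, 0.0 outside (broadcastable)
    percent_through: float,
    base_weight: float = 1.0,    # 1.0 for cross-attention, 0.3 for self-attention above
) -> torch.Tensor:
    """Turn a binary region mask into an additive attention bias (sketch of the code above)."""
    decay = base_weight * (1.0 - percent_through) ** 5  # influence fades over the schedule
    m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight
    m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0]
    # Boost in-region scores toward the row max, push out-of-region scores toward the row min.
    return decay * (m_pos * region_mask - m_neg * (1.0 - region_mask))
```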
@@ -1,164 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
TextConditioningRegions,
|
||||
)
|
||||
|
||||
|
||||
class RegionalPromptData:
|
||||
def __init__(
|
||||
self,
|
||||
regions: list[TextConditioningRegions],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
max_downscale_factor: int = 8,
|
||||
):
|
||||
"""Initialize a `RegionalPromptData` object.
|
||||
|
||||
Args:
|
||||
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
|
||||
batch.
|
||||
device (torch.device): The device to use for the attention masks.
|
||||
dtype (torch.dtype): The data type to use for the attention masks.
|
||||
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
|
||||
in steps of 2x.
|
||||
"""
|
||||
self._regions = regions
|
||||
self._device = device
|
||||
self._dtype = dtype
|
||||
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
|
||||
# sequence length of s.
|
||||
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
|
||||
regions, max_downscale_factor
|
||||
)
|
||||
self._negative_cross_attn_mask_score = 0.0
|
||||
self._size_weight = 1.0
|
||||
|
||||
def _prepare_spatial_masks(
|
||||
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
||||
) -> list[dict[int, torch.Tensor]]:
|
||||
"""Prepare the spatial masks for all downscaling factors."""
|
||||
# batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
|
||||
# of s.
|
||||
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
|
||||
|
||||
for batch_sample_regions in regions:
|
||||
batch_sample_masks_by_seq_len.append({})
|
||||
|
||||
# Convert the bool masks to float masks so that max pooling can be applied.
|
||||
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
|
||||
|
||||
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
|
||||
downscale_factor = 1
|
||||
while downscale_factor <= max_downscale_factor:
|
||||
b, _num_prompts, h, w = batch_sample_masks.shape
|
||||
assert b == 1
|
||||
query_seq_len = h * w
|
||||
|
||||
batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks
|
||||
|
||||
downscale_factor *= 2
|
||||
if downscale_factor <= max_downscale_factor:
|
||||
# We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
|
||||
# regions to be lost entirely.
|
||||
# TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
|
||||
# potentially use a weighted mask rather than a binary mask.
|
||||
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2)
|
||||
|
||||
return batch_sample_masks_by_seq_len
|
||||
|
||||
def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
|
||||
"""Get the cross-attention mask for the given query sequence length.
|
||||
|
||||
Args:
|
||||
query_seq_len: The length of the flattened spatial features at the current downscaling level.
|
||||
key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention
|
||||
layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The masks.
|
||||
shape: (batch_size, query_seq_len, key_seq_len).
|
||||
dtype: float
|
||||
The mask is a binary mask with values of 0.0 and 1.0.
|
||||
"""
|
||||
batch_size = len(self._spatial_masks_by_seq_len)
|
||||
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
|
||||
|
||||
# Create an empty attention mask with the correct shape.
|
||||
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
|
||||
batch_sample_regions = self._regions[batch_idx]
|
||||
|
||||
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
|
||||
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
|
||||
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
|
||||
|
||||
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
|
||||
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :]
|
||||
size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
|
||||
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
|
||||
# size = size.to(dtype=batch_sample_query_scores.dtype)
|
||||
# batch_sample_query_mask = batch_sample_query_scores > 0.5
|
||||
# batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
|
||||
# batch_sample_query_scores[~batch_sample_query_mask] = 0.0
|
||||
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores * (
|
||||
mask_weight + self._size_weight * (1 - size)
|
||||
)
|
||||
|
||||
return attn_mask
|
||||
|
||||
def get_self_attn_mask(self, query_seq_len: int, percent_through: float) -> torch.Tensor:
|
||||
"""Get the self-attention mask for the given query sequence length.
|
||||
|
||||
Args:
|
||||
query_seq_len: The length of the flattened spatial features at the current downscaling level.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The masks.
|
||||
shape: (batch_size, query_seq_len, query_seq_len).
|
||||
dtype: float
|
||||
The mask is a binary mask with values of 0.0 and 1.0.
|
||||
"""
|
||||
batch_size = len(self._spatial_masks_by_seq_len)
|
||||
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
|
||||
|
||||
# Create an empty attention mask with the correct shape.
|
||||
attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=self._dtype, device=self._device)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
|
||||
batch_sample_regions = self._regions[batch_idx]
|
||||
|
||||
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
|
||||
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
|
||||
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
|
||||
|
||||
for prompt_idx in range(num_prompts):
|
||||
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
|
||||
size = prompt_query_mask.sum() / prompt_query_mask.numel()
|
||||
size = size.to(dtype=prompt_query_mask.dtype)
|
||||
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
|
||||
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
|
||||
# query_seq_len) mask.
|
||||
# TODO(ryand): Is += really the best option here? Maybe elementwise max is better?
|
||||
attn_mask[batch_idx, :, :] = torch.maximum(
|
||||
attn_mask[batch_idx, :, :],
|
||||
prompt_query_mask.unsqueeze(0)
|
||||
* prompt_query_mask.unsqueeze(1)
|
||||
* (mask_weight + self._size_weight * (1 - size)),
|
||||
)
|
||||
|
||||
# if attn_mask[batch_idx].max() < 0.01:
|
||||
# attn_mask[batch_idx, ...] = 1.0
|
||||
|
||||
# attn_mask[attn_mask > 0.5] = 1.0
|
||||
# attn_mask[attn_mask <= 0.5] = 0.0
|
||||
# attn_mask_min = attn_mask[batch_idx].min()
|
||||
|
||||
# # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
|
||||
# if abs(attn_mask_min) > 0.0001:
|
||||
# attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
|
||||
return attn_mask
|
||||
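`_prepare_spatial_masks` above indexes each pre-pooled mask level by its flattened query length so the attention processor can look masks up by `query_seq_len`. A minimal sketch of that preparation:

```python
import torch
import torch.nn.functional as F

def masks_by_query_seq_len(mask: torch.Tensor, max_downscale_factor: int = 8) -> dict:
    """Max-pool a (1, num_prompts, H, W) mask by 2x steps; key each level by H*W (sketch)."""
    out: dict = {}
    downscale = 1
    while downscale <= max_downscale_factor:
        _, _, h, w = mask.shape
        out[h * w] = mask
        downscale *= 2
        if downscale <= max_downscale_factor:
            # Max pooling keeps small regions alive at coarse resolutions.
            mask = F.max_pool2d(mask, kernel_size=2, stride=2)
    return out

levels = masks_by_query_seq_len(torch.ones((1, 2, 64, 64)))
print(sorted(levels.keys()))  # [64, 256, 1024, 4096]
```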
@@ -1,7 +1,6 @@
from __future__ import annotations

import math
import time
from contextlib import contextmanager
from typing import Any, Callable, Optional, Union

@@ -11,13 +10,11 @@ from typing_extensions import TypeAlias

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningData,
    ExtraConditioningInfo,
    IPAdapterConditioningInfo,
    Range,
    TextConditioningData,
    TextConditioningRegions,
    PostprocessingSettings,
    SDXLConditioningInfo,
)
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

from .cross_attention_control import (
    CrossAttentionType,
@@ -59,6 +56,7 @@ class InvokeAIDiffuserComponent:
|
||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||
"""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
@@ -93,7 +91,7 @@ class InvokeAIDiffuserComponent:
|
||||
timestep: torch.Tensor,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
conditioning_data: TextConditioningData,
|
||||
conditioning_data,
|
||||
):
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
|
||||
@@ -126,30 +124,38 @@ class InvokeAIDiffuserComponent:
|
||||
added_cond_kwargs = None
|
||||
|
||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||
if conditioning_data.is_sdxl():
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.cond_text.pooled_embeds,
|
||||
"time_ids": conditioning_data.cond_text.add_time_ids,
|
||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||
}
|
||||
encoder_hidden_states = conditioning_data.cond_text.embeds
|
||||
encoder_hidden_states = conditioning_data.text_embeddings.embeds
|
||||
encoder_attention_mask = None
|
||||
else:
|
||||
if conditioning_data.is_sdxl():
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
conditioning_data.text_embeddings.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[conditioning_data.uncond_text.add_time_ids, conditioning_data.cond_text.add_time_ids],
|
||||
[
|
||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
conditioning_data.text_embeddings.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
(encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
|
||||
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
|
||||
(
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = self._concat_conditionings_for_batch(
|
||||
conditioning_data.unconditioned_embeddings.embeds,
|
||||
conditioning_data.text_embeddings.embeds,
|
||||
)
|
||||
if isinstance(control_datum.weight, list):
|
||||
# if controlnet has multiple weights, use the weight for the current step
|
||||
@@ -193,17 +199,16 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
conditioning_data: TextConditioningData,
|
||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
|
||||
conditioning_data: ConditioningData,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
):
|
||||
percent_through = step_index / total_step_count
|
||||
cross_attention_control_types_to_do = []
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = step_index / total_step_count
|
||||
cross_attention_control_types_to_do = (
|
||||
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
|
||||
)
|
||||
@@ -219,8 +224,6 @@ class InvokeAIDiffuserComponent:
|
||||
x=sample,
|
||||
sigma=timestep,
|
||||
conditioning_data=conditioning_data,
|
||||
ip_adapter_conditioning=ip_adapter_conditioning,
|
||||
percent_through=percent_through,
|
||||
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
@@ -234,8 +237,6 @@ class InvokeAIDiffuserComponent:
|
||||
x=sample,
|
||||
sigma=timestep,
|
||||
conditioning_data=conditioning_data,
|
||||
percent_through=percent_through,
|
||||
ip_adapter_conditioning=ip_adapter_conditioning,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
@@ -243,6 +244,19 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def do_latent_postprocessing(
|
||||
self,
|
||||
postprocessing_settings: PostprocessingSettings,
|
||||
latents: torch.Tensor,
|
||||
sigma,
|
||||
step_index,
|
||||
total_step_count,
|
||||
) -> torch.Tensor:
|
||||
if postprocessing_settings is not None:
|
||||
percent_through = step_index / total_step_count
|
||||
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
|
||||
return latents
|
||||
|
||||
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
|
||||
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
||||
conditioning_attention_mask = torch.ones(
|
||||
@@ -290,13 +304,13 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def _apply_standard_conditioning(
|
||||
self,
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data: TextConditioningData,
|
||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
|
||||
percent_through: float,
|
||||
conditioning_data: ConditioningData,
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
@@ -307,55 +321,41 @@ class InvokeAIDiffuserComponent:
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
cross_attention_kwargs = {}
|
||||
if ip_adapter_conditioning is not None:
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
|
||||
torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
|
||||
for ipa_conditioning in ip_adapter_conditioning
|
||||
]
|
||||
|
||||
uncond_text = conditioning_data.uncond_text
|
||||
cond_text = conditioning_data.cond_text
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
torch.stack(
|
||||
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
|
||||
)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
|
||||
added_cond_kwargs = None
|
||||
if conditioning_data.is_sdxl():
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat([uncond_text.pooled_embeds, cond_text.pooled_embeds], dim=0),
|
||||
"time_ids": torch.cat([uncond_text.add_time_ids, cond_text.add_time_ids], dim=0),
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
conditioning_data.text_embeddings.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
conditioning_data.text_embeddings.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
|
||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||
uncond_text.embeds, cond_text.embeds
|
||||
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
|
||||
)
|
||||
|
||||
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
|
||||
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
|
||||
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
|
||||
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
|
||||
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
|
||||
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
|
||||
regions = []
|
||||
for c, r in [
|
||||
(conditioning_data.uncond_text, conditioning_data.uncond_regions),
|
||||
(conditioning_data.cond_text, conditioning_data.cond_regions),
|
||||
]:
|
||||
if r is None:
|
||||
# Create a dummy mask and range for text conditioning that doesn't have region masks.
|
||||
_, _, h, w = x.shape
|
||||
r = TextConditioningRegions(
|
||||
masks=torch.ones((1, 1, h, w), dtype=torch.bool),
|
||||
ranges=[Range(start=0, end=c.embeds.shape[1])],
|
||||
mask_weights=[0.0],
|
||||
)
|
||||
regions.append(r)
|
||||
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||
regions=regions, device=x.device, dtype=x.dtype
|
||||
)
|
||||
cross_attention_kwargs["percent_through"] = percent_through
|
||||
time.sleep(1.0)
|
||||
|
||||
both_results = self.model_forward_callback(
|
||||
x_twice,
|
||||
sigma_twice,
|
||||
@@ -374,10 +374,8 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
conditioning_data: TextConditioningData,
|
||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
|
||||
conditioning_data: ConditioningData,
|
||||
cross_attention_control_types_to_do: list[CrossAttentionType],
|
||||
percent_through: float,
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
@@ -424,40 +422,36 @@ class InvokeAIDiffuserComponent:
|
||||
# Unconditioned pass
|
||||
#####################
|
||||
|
||||
cross_attention_kwargs = {}
|
||||
cross_attention_kwargs = None
|
||||
|
||||
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
|
||||
if ip_adapter_conditioning is not None:
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
|
||||
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in ip_adapter_conditioning
|
||||
]
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
|
||||
# Prepare cross-attention control kwargs for the unconditioned pass.
|
||||
if cross_attn_processor_context is not None:
|
||||
cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
|
||||
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
|
||||
|
||||
# Prepare SDXL conditioning kwargs for the unconditioned pass.
|
||||
added_cond_kwargs = None
|
||||
if conditioning_data.is_sdxl():
|
||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
||||
if is_sdxl:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.uncond_text.pooled_embeds,
|
||||
"time_ids": conditioning_data.uncond_text.add_time_ids,
|
||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
}
|
||||
|
||||
# Prepare prompt regions for the unconditioned pass.
|
||||
if conditioning_data.uncond_regions is not None:
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
|
||||
)
|
||||
cross_attention_kwargs["percent_through"] = percent_through
|
||||
|
||||
# Run unconditioned UNet denoising (i.e. negative prompt).
|
||||
unconditioned_next_x = self.model_forward_callback(
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data.uncond_text.embeds,
|
||||
conditioning_data.unconditioned_embeddings.embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=uncond_down_block,
|
||||
mid_block_additional_residual=uncond_mid_block,
|
||||
@@ -469,41 +463,36 @@ class InvokeAIDiffuserComponent:
# Conditioned pass
###################

cross_attention_kwargs = {}
cross_attention_kwargs = None

# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
if ip_adapter_conditioning is not None:
if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in ip_adapter_conditioning
]
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}

# Prepare cross-attention control kwargs for the conditioned pass.
if cross_attn_processor_context is not None:
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}

# Prepare SDXL conditioning kwargs for the conditioned pass.
added_cond_kwargs = None
if conditioning_data.is_sdxl():
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.cond_text.add_time_ids,
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
}

# Prepare prompt regions for the conditioned pass.
if conditioning_data.cond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through

# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.cond_text.embeds,
conditioning_data.text_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
@@ -517,3 +506,64 @@ class InvokeAIDiffuserComponent:
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta
return combined_next_x

def apply_symmetry(
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
percent_through: float,
) -> torch.Tensor:
# Reset our last percent through if this is our first step.
if percent_through == 0.0:
self.last_percent_through = 0.0

if postprocessing_settings is None:
return latents

# Check for out of bounds
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
h_symmetry_time_pct = None

v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
v_symmetry_time_pct = None

dev = latents.device.type

latents.to(device="cpu")

if (
h_symmetry_time_pct is not None
and self.last_percent_through < h_symmetry_time_pct
and percent_through >= h_symmetry_time_pct
):
# Horizontal symmetry occurs on the 3rd dimension of the latent
width = latents.shape[3]
x_flipped = torch.flip(latents, dims=[3])
latents = torch.cat(
[
latents[:, :, :, 0 : int(width / 2)],
x_flipped[:, :, :, int(width / 2) : int(width)],
],
dim=3,
)

if (
v_symmetry_time_pct is not None
and self.last_percent_through < v_symmetry_time_pct
and percent_through >= v_symmetry_time_pct
):
# Vertical symmetry occurs on the 2nd dimension of the latent
height = latents.shape[2]
y_flipped = torch.flip(latents, dims=[2])
latents = torch.cat(
[
latents[:, :, 0 : int(height / 2)],
y_flipped[:, :, int(height / 2) : int(height)],
],
dim=2,
)

self.last_percent_through = percent_through
return latents.to(device=dev)

@@ -858,9 +858,9 @@ def do_textual_inversion_training(
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

@@ -144,7 +144,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):

self.nextrely = top_of_table
self.lora_models = self.add_model_widgets(
model_type=ModelType.Lora,
model_type=ModelType.LoRA,
window_width=window_width,
)
bottom_of_table = max(bottom_of_table, self.nextrely)

@@ -30,7 +30,7 @@
"lint:prettier": "prettier --check .",
"lint:tsc": "tsc --noEmit",
"lint": "concurrently -g -c red,green,yellow,blue,magenta pnpm:lint:*",
"fix": "knip --fix && eslint --fix . && prettier --log-level warn --write .",
"fix": "eslint --fix . && prettier --log-level warn --write .",
"preinstall": "npx only-allow pnpm",
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",

@@ -304,6 +304,12 @@
"method": "High Resolution Fix Method"
}
},
"prompt": {
"addPromptTrigger": "Add Prompt Trigger",
"compatibleEmbeddings": "Compatible Embeddings",
"noPromptTriggers": "No triggers available",
"noMatchingTriggers": "No matching triggers"
},
"embedding": {
"addEmbedding": "Add Embedding",
"incompatibleModel": "Incompatible base model:",

@@ -153,7 +153,7 @@ addFirstListImagesListener(startAppListening);
// Ad-hoc upscale workflwo
addUpscaleRequestedListener(startAppListening);

// Dynamic prompts
// Prompts
addDynamicPromptsListener(startAppListening);

addSetDefaultSettingsListener(startAppListening);

@@ -7,8 +7,10 @@ import {
selectAllT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features/parameters/store/generationSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
@@ -24,7 +26,9 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
const log = logger('models');
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);

const currentModel = getState().generation.model;
const state = getState();

const currentModel = state.generation.model;
const models = mainModelsAdapterSelectors.selectAll(action.payload);

if (models.length === 0) {
@@ -39,6 +43,29 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
return;
}

const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;

if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));

const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);

dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}

const result = zParameterModel.safeParse(models[0]);

if (!result.success) {

@@ -34,13 +34,13 @@ export const addSetDefaultSettingsListener = (startAppListeni
return;
}

const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap();
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();

if (!metadata || !metadata.default_settings) {
if (!modelConfig || !modelConfig.default_settings) {
return;
}

const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings;
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;

if (vae) {
// we store this as "default" within default settings

@@ -14,7 +14,7 @@ export const addModelInstallEventListener = (startAppListenin
const { bytes, total_bytes, id } = action.payload.data;

dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.bytes = bytes;
@@ -33,7 +33,7 @@ export const addModelInstallEventListener = (startAppListenin
const { id } = action.payload.data;

dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'completed';
@@ -41,7 +41,7 @@ export const addModelInstallEventListener = (startAppListenin
return draft;
})
);
dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }]));
dispatch(api.util.invalidateTags(['Model']));
},
});

@@ -51,7 +51,7 @@ export const addModelInstallEventListener = (startAppListenin
const { id, error, error_type } = action.payload.data;

dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'error';

@@ -1,21 +0,0 @@
import type { Meta, StoryObj } from '@storybook/react';

import { EmbeddingSelect } from './EmbeddingSelect';
import type { EmbeddingSelectProps } from './types';

const meta: Meta<typeof EmbeddingSelect> = {
title: 'Feature/Prompt/EmbeddingSelect',
tags: ['autodocs'],
component: EmbeddingSelect,
};

export default meta;
type Story = StoryObj<typeof EmbeddingSelect>;

const Component = (props: EmbeddingSelectProps) => {
  return <EmbeddingSelect {...props}>Invoke</EmbeddingSelect>;
};

export const Default: Story = {
  render: Component,
};
@@ -1,67 +0,0 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import type { EmbeddingSelectProps } from 'features/embedding/types';
|
||||
import { t } from 'i18next';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { TextualInversionModelConfig } from 'services/api/types';
|
||||
|
||||
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
|
||||
|
||||
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(embedding: TextualInversionModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === embedding.base;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(embedding: TextualInversionModelConfig | null) => {
|
||||
if (!embedding) {
|
||||
return;
|
||||
}
|
||||
onSelect(embedding.name);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const { options, onChange } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
getIsDisabled,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<Combobox
|
||||
placeholder={isLoading ? t('common.loading') : t('embedding.addEmbedding')}
|
||||
defaultMenuIsOpen
|
||||
autoFocus
|
||||
value={null}
|
||||
options={options}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
onChange={onChange}
|
||||
onMenuClose={onClose}
|
||||
data-testid="add-embedding"
|
||||
sx={selectStyles}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
EmbeddingSelect.displayName = 'EmbeddingSelect';
|
||||
|
||||
const selectStyles: ChakraProps['sx'] = {
|
||||
w: 'full',
|
||||
};
|
||||
@@ -1,228 +0,0 @@
|
||||
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
|
||||
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
|
||||
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
|
||||
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
|
||||
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
|
||||
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
|
||||
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { isNil, omitBy } from 'lodash-es';
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
export const AdvancedImport = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [installModel] = useInstallModelMutation();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
setValue,
|
||||
resetField,
|
||||
reset,
|
||||
watch,
|
||||
} = useForm<AnyModelConfig>({
|
||||
defaultValues: {
|
||||
name: '',
|
||||
base: 'sd-1',
|
||||
type: 'main',
|
||||
path: '',
|
||||
description: '',
|
||||
format: 'diffusers',
|
||||
vae: '',
|
||||
variant: 'normal',
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
||||
(values) => {
|
||||
installModel({
|
||||
source: values.path,
|
||||
config: omitBy(values, isNil),
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelAdded', {
|
||||
modelName: values.name,
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
reset();
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[installModel, dispatch, t, reset]
|
||||
);
|
||||
|
||||
const watchedModelType = watch('type');
|
||||
const watchedModelFormat = watch('format');
|
||||
|
||||
useEffect(() => {
|
||||
if (watchedModelType === 'main') {
|
||||
setValue('format', 'diffusers');
|
||||
setValue('repo_variant', '');
|
||||
setValue('variant', 'normal');
|
||||
}
|
||||
if (watchedModelType === 'lora') {
|
||||
setValue('format', 'lycoris');
|
||||
} else if (watchedModelType === 'embedding') {
|
||||
setValue('format', 'embedding_file');
|
||||
} else if (watchedModelType === 'ip_adapter') {
|
||||
setValue('format', 'invokeai');
|
||||
} else {
|
||||
setValue('format', 'diffusers');
|
||||
}
|
||||
resetField('upcast_attention');
|
||||
resetField('ztsnr_training');
|
||||
resetField('vae');
|
||||
resetField('config');
|
||||
resetField('prediction_type');
|
||||
resetField('image_encoder_model_id');
|
||||
}, [watchedModelType, resetField, setValue]);
|
||||
|
||||
return (
|
||||
<ScrollableContent>
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
|
||||
<Flex alignItems="flex-end" gap="4">
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
||||
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
||||
</FormControl>
|
||||
<Text px="2" fontSize="xs" textAlign="center">
|
||||
{t('modelManager.advancedImportInfo')}
|
||||
</Text>
|
||||
</Flex>
|
||||
|
||||
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
|
||||
<FormControl isInvalid={Boolean(errors.name)}>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
<Flex>
|
||||
<FormControl>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Textarea size="sm" {...register('description')} />
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={control} name="base" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('common.format')}</FormLabel>
|
||||
<ModelFormatSelect control={control} name="format" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.path')}</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
</Flex>
|
||||
{watchedModelType === 'main' && (
|
||||
<>
|
||||
<Flex gap={4}>
|
||||
{watchedModelFormat === 'diffusers' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
|
||||
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
||||
</FormControl>
|
||||
)}
|
||||
{watchedModelFormat === 'checkpoint' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
<Input {...register('config')} />
|
||||
</FormControl>
|
||||
)}
|
||||
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{watchedModelType === 'ip_adapter' && (
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
|
||||
<Input {...register('image_encoder_model_id')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
)}
|
||||
<Button mt={2} type="submit">
|
||||
{t('modelManager.addModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</form>
|
||||
</ScrollableContent>
|
||||
);
|
||||
};
|
||||
@@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
location: string;
};

export const SimpleImport = () => {
export const InstallModelForm = () => {
const dispatch = useAppDispatch();

const [installModel, { isLoading }] = useInstallModelMutation();
@@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
|
||||
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { ImportQueueItem } from './ImportQueueItem';
|
||||
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
|
||||
|
||||
export const ImportQueue = () => {
|
||||
export const ModelInstallQueue = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data } = useGetModelImportsQuery();
|
||||
const { data } = useListModelInstallsQuery();
|
||||
|
||||
const [pruneModelImports] = usePruneModelImportsMutation();
|
||||
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
|
||||
|
||||
const pruneQueue = useCallback(() => {
|
||||
pruneModelImports()
|
||||
const pruneCompletedModelInstalls = useCallback(() => {
|
||||
_pruneCompletedModelInstalls()
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
@@ -41,7 +41,7 @@ export const ImportQueue = () => {
|
||||
);
|
||||
}
|
||||
});
|
||||
}, [pruneModelImports, dispatch]);
|
||||
}, [_pruneCompletedModelInstalls, dispatch]);
|
||||
|
||||
const pruneAvailable = useMemo(() => {
|
||||
return data?.some(
|
||||
@@ -53,14 +53,19 @@ export const ImportQueue = () => {
|
||||
<Flex flexDir="column" p={3} h="full">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Text>{t('modelManager.importQueue')}</Text>
|
||||
<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
|
||||
<Button
|
||||
size="sm"
|
||||
isDisabled={!pruneAvailable}
|
||||
onClick={pruneCompletedModelInstalls}
|
||||
tooltip={t('modelManager.pruneTooltip')}
|
||||
>
|
||||
{t('modelManager.prune')}
|
||||
</Button>
|
||||
</Flex>
|
||||
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column-reverse" gap="2">
|
||||
{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
|
||||
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Box>
|
||||
@@ -6,17 +6,24 @@ import type { ModelInstallStatus } from 'services/api/types';
|
||||
const STATUSES = {
|
||||
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
|
||||
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
||||
downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
||||
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
||||
completed: { colorScheme: 'green', translationKey: 'queue.completed' },
|
||||
error: { colorScheme: 'red', translationKey: 'queue.failed' },
|
||||
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
|
||||
};
|
||||
|
||||
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
|
||||
const ModelInstallQueueBadge = ({
|
||||
status,
|
||||
errorReason,
|
||||
}: {
|
||||
status?: ModelInstallStatus;
|
||||
errorReason?: string | null;
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!status || !Object.keys(STATUSES).includes(status)) {
|
||||
return <></>;
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -25,4 +32,4 @@ const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
export default memo(ImportQueueBadge);
|
||||
export default memo(ModelInstallQueueBadge);
|
||||
@@ -3,15 +3,16 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
|
||||
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
|
||||
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
|
||||
|
||||
import ImportQueueBadge from './ImportQueueBadge';
|
||||
import ModelInstallQueueBadge from './ModelInstallQueueBadge';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: ModelInstallJob;
|
||||
installJob: ModelInstallJob;
|
||||
};
|
||||
|
||||
const formatBytes = (bytes: number) => {
|
||||
@@ -26,26 +27,26 @@ const formatBytes = (bytes: number) => {
|
||||
return `${bytes.toFixed(2)} ${units[i]}`;
|
||||
};
|
||||
|
||||
export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
const { model } = props;
|
||||
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
||||
const { installJob } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [deleteImportModel] = useDeleteModelImportMutation();
|
||||
const [deleteImportModel] = useCancelModelInstallMutation();
|
||||
|
||||
const source = useMemo(() => {
|
||||
if (model.source.type === 'hf') {
|
||||
return model.source as HFModelSource;
|
||||
} else if (model.source.type === 'local') {
|
||||
return model.source as LocalModelSource;
|
||||
} else if (model.source.type === 'url') {
|
||||
return model.source as URLModelSource;
|
||||
if (installJob.source.type === 'hf') {
|
||||
return installJob.source as HFModelSource;
|
||||
} else if (installJob.source.type === 'local') {
|
||||
return installJob.source as LocalModelSource;
|
||||
} else if (installJob.source.type === 'url') {
|
||||
return installJob.source as URLModelSource;
|
||||
} else {
|
||||
return model.source as LocalModelSource;
|
||||
return installJob.source as LocalModelSource;
|
||||
}
|
||||
}, [model.source]);
|
||||
}, [installJob.source]);
|
||||
|
||||
const handleDeleteModelImport = useCallback(() => {
|
||||
deleteImportModel(model.id)
|
||||
deleteImportModel(installJob.id)
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
@@ -69,7 +70,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
);
|
||||
}
|
||||
});
|
||||
}, [deleteImportModel, model, dispatch]);
|
||||
}, [deleteImportModel, installJob, dispatch]);
|
||||
|
||||
const modelName = useMemo(() => {
|
||||
switch (source.type) {
|
||||
@@ -85,19 +86,23 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
}, [source]);
|
||||
|
||||
const progressValue = useMemo(() => {
|
||||
if (model.bytes === undefined || model.total_bytes === undefined) {
|
||||
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (installJob.total_bytes === 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return (model.bytes / model.total_bytes) * 100;
|
||||
}, [model.bytes, model.total_bytes]);
|
||||
return (installJob.bytes / installJob.total_bytes) * 100;
|
||||
}, [installJob.bytes, installJob.total_bytes]);
|
||||
|
||||
const progressString = useMemo(() => {
|
||||
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
|
||||
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
|
||||
return '';
|
||||
}
|
||||
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
|
||||
}, [model.bytes, model.total_bytes, model.status]);
|
||||
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
|
||||
}, [installJob.bytes, installJob.total_bytes, installJob.status]);
|
||||
|
||||
return (
|
||||
<Flex gap="2" w="full" alignItems="center">
|
||||
@@ -109,19 +114,21 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
<Flex flexDir="column" flex={1}>
|
||||
<Tooltip label={progressString}>
|
||||
<Progress
|
||||
value={progressValue}
|
||||
isIndeterminate={progressValue === undefined}
|
||||
value={progressValue ?? 0}
|
||||
isIndeterminate={progressValue === null}
|
||||
aria-label={t('accessibility.invokeProgressBar')}
|
||||
h={2}
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
<Box minW="100px" textAlign="center">
|
||||
<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
|
||||
<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
|
||||
</Box>
|
||||
|
||||
<Box minW="20px">
|
||||
{(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && (
|
||||
{(installJob.status === 'downloading' ||
|
||||
installJob.status === 'waiting' ||
|
||||
installJob.status === 'running') && (
|
||||
<IconButton
|
||||
isRound={true}
|
||||
size="xs"
|
||||
@@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLazyScanModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { ScanModelsResults } from './ScanModelsResults';
|
||||
import { ScanModelsResults } from './ScanFolderResults';
|
||||
|
||||
export const ScanModelsForm = () => {
|
||||
const [scanPath, setScanPath] = useState('');
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
|
||||
const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
|
||||
|
||||
const handleSubmitScan = useCallback(async () => {
|
||||
_scanModels({ scan_path: scanPath }).catch((error) => {
|
||||
const scanFolder = useCallback(async () => {
|
||||
_scanFolder({ scan_path: scanPath }).catch((error) => {
|
||||
if (error) {
|
||||
setErrorMessage(error.data.detail);
|
||||
}
|
||||
});
|
||||
}, [_scanModels, scanPath]);
|
||||
}, [_scanFolder, scanPath]);
|
||||
|
||||
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||
setScanPath(e.target.value);
|
||||
@@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
|
||||
<Input value={scanPath} onChange={handleSetScanPath} />
|
||||
</Flex>
|
||||
|
||||
<Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}>
|
||||
<Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}>
|
||||
{t('modelManager.scanFolder')}
|
||||
</Button>
|
||||
</Flex>
|
||||
@@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';

import { ScanModelResultItem } from './ScanModelResultItem';
import { ScanModelResultItem } from './ScanFolderResultItem';

type ScanModelResultsProps = {
results: ScanFolderResponse;
@@ -1,12 +1,11 @@
|
||||
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { AdvancedImport } from './AddModelPanel/AdvancedImport';
|
||||
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
|
||||
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
|
||||
import { SimpleImport } from './AddModelPanel/SimpleImport';
|
||||
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
|
||||
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
|
||||
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
||||
|
||||
export const ImportModels = () => {
|
||||
export const InstallModels = () => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
|
||||
@@ -17,15 +16,11 @@ export const ImportModels = () => {
|
||||
<Tabs variant="collapse" height="100%">
|
||||
<TabList>
|
||||
<Tab>{t('common.simple')}</Tab>
|
||||
<Tab>{t('modelManager.advanced')}</Tab>
|
||||
<Tab>{t('modelManager.scan')}</Tab>
|
||||
</TabList>
|
||||
<TabPanels p={3} height="100%">
|
||||
<TabPanel>
|
||||
<SimpleImport />
|
||||
</TabPanel>
|
||||
<TabPanel height="100%">
|
||||
<AdvancedImport />
|
||||
<InstallModelForm />
|
||||
</TabPanel>
|
||||
<TabPanel height="100%">
|
||||
<ScanModelsForm />
|
||||
@@ -34,7 +29,7 @@ export const ImportModels = () => {
|
||||
</Tabs>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" w="full" h="50%">
|
||||
<ImportQueue />
|
||||
<ModelInstallQueue />
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
@@ -5,7 +5,7 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoFilter } from 'react-icons/io5';

export const MODEL_TYPE_LABELS: { [key: string]: string } = {
const MODEL_TYPE_LABELS: { [key: string]: string } = {
main: 'Main',
lora: 'LoRA',
embedding: 'Textual Inversion',

@@ -1,14 +1,14 @@
import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';

import { ImportModels } from './ImportModels';
import { InstallModels } from './InstallModels';
import { Model } from './ModelPanel/Model';

export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
</Box>
);
};

@@ -1,11 +1,12 @@
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
|
||||
|
||||
@@ -23,8 +24,9 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
|
||||
|
||||
export const DefaultSettings = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
|
||||
useAppSelector(initialStatesSelector);
|
||||
|
||||
@@ -59,7 +61,7 @@ export const DefaultSettings = () => {
|
||||
]);
|
||||
|
||||
if (isLoading) {
|
||||
return <Loading />;
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
|
||||
|
||||
@@ -8,7 +8,7 @@ import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||
@@ -41,7 +41,7 @@ export const DefaultSettingsForm = ({
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
|
||||
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
|
||||
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
||||
|
||||
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
|
||||
defaultValues: defaultSettingsDefaults,
|
||||
@@ -62,7 +62,7 @@ export const DefaultSettingsForm = ({
|
||||
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
|
||||
};
|
||||
|
||||
editModelMetadata({
|
||||
updateModel({
|
||||
key: selectedModelKey,
|
||||
body: { default_settings: body },
|
||||
})
|
||||
@@ -90,7 +90,7 @@ export const DefaultSettingsForm = ({
|
||||
}
|
||||
});
|
||||
},
|
||||
[selectedModelKey, dispatch, editModelMetadata, t]
|
||||
[selectedModelKey, dispatch, updateModel, t]
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -3,9 +3,9 @@ import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||
@@ -14,8 +14,12 @@ const options: ComboboxOption[] = [
|
||||
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||
];
|
||||
|
||||
const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
const { field } = useController(props);
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
};
|
||||
|
||||
const BaseModelSelect = ({ control }: Props) => {
|
||||
const { field } = useController({ control, name: 'base' });
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
{ value: 'true', label: 'True' },
|
||||
{ value: 'false', label: 'False' },
|
||||
];
|
||||
|
||||
const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value === 'true');
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(BooleanSelect);
|
||||
@@ -1,47 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController, useWatch } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
const { field, formState } = useController(props);
|
||||
const type = useWatch({ control: props.control, name: 'type' });
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const options: ComboboxOption[] = useMemo(() => {
|
||||
const modelType = type || formState.defaultValues?.type;
|
||||
if (modelType === 'lora') {
|
||||
return [
|
||||
{ value: 'lycoris', label: 'LyCORIS' },
|
||||
{ value: 'diffusers', label: 'Diffusers' },
|
||||
];
|
||||
} else if (modelType === 'embedding') {
|
||||
return [
|
||||
{ value: 'embedding_file', label: 'Embedding File' },
|
||||
{ value: 'embedding_folder', label: 'Embedding Folder' },
|
||||
];
|
||||
} else if (modelType === 'ip_adapter') {
|
||||
return [{ value: 'invokeai', label: 'invokeai' }];
|
||||
} else {
|
||||
return [
|
||||
{ value: 'diffusers', label: 'Diffusers' },
|
||||
{ value: 'checkpoint', label: 'Checkpoint' },
|
||||
];
|
||||
}
|
||||
}, [type, formState.defaultValues?.type]);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
|
||||
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(ModelFormatSelect);
|
||||
@@ -1,32 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { MODEL_TYPE_LABELS } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
|
||||
{ value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
|
||||
{ value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
|
||||
{ value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
|
||||
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
|
||||
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
|
||||
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
|
||||
] as const;
|
||||
|
||||
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(ModelTypeSelect);
|
||||
@@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'normal', label: 'Normal' },
|
||||
@@ -12,8 +12,12 @@ const options: ComboboxOption[] = [
|
||||
{ value: 'depth', label: 'Depth' },
|
||||
];
|
||||
|
||||
const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
};
|
||||
|
||||
const ModelVariantSelect = ({ control }: Props) => {
|
||||
const { field } = useController({ control, name: 'variant' });
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
|
||||
@@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
@@ -13,8 +13,12 @@ const options: ComboboxOption[] = [
|
||||
{ value: 'sample', label: 'sample' },
|
||||
];
|
||||
|
||||
const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
};
|
||||
|
||||
const PredictionTypeSelect = ({ control }: Props) => {
|
||||
const { field } = useController({ control, name: 'prediction_type' });
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
{ value: 'fp16', label: 'fp16' },
|
||||
{ value: 'fp32', label: 'fp32' },
|
||||
];
|
||||
|
||||
const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(RepoVariantSelect);
|
||||
@@ -1,18 +1,41 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
import type { ModelType } from 'services/api/types';
|
||||
|
||||
import { TriggerPhrases } from './TriggerPhrases';
|
||||
|
||||
const MODEL_TYPE_TRIGGER_PHRASE: ModelType[] = ['main', 'lora'];
|
||||
|
||||
export const ModelMetadata = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
const { data } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const shouldShowTriggerPhraseSettings = useMemo(() => {
|
||||
if (!data?.type) {
|
||||
return false;
|
||||
}
|
||||
return MODEL_TYPE_TRIGGER_PHRASE.includes(data.type);
|
||||
}, [data]);
|
||||
|
||||
const apiResponseFormatted = useMemo(() => {
|
||||
if (!data?.source_api_response) {
|
||||
return {};
|
||||
}
|
||||
return JSON.parse(data.source_api_response);
|
||||
}, [data?.source_api_response]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex flexDir="column" height="full" gap="3">
|
||||
<DataViewer label="metadata" data={metadata || {}} />
|
||||
</Flex>
|
||||
</>
|
||||
<Flex flexDir="column" height="full" gap="3">
|
||||
{shouldShowTriggerPhraseSettings && (
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<TriggerPhrases />
|
||||
</Box>
|
||||
)}
|
||||
<DataViewer label="metadata" data={apiResponseFormatted} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
Input,
|
||||
Tag,
|
||||
TagCloseButton,
|
||||
TagLabel,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ModelListHeader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
export const TriggerPhrases = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: modelConfig } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const [phrase, setPhrase] = useState('');
|
||||
|
||||
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
||||
|
||||
const handlePhraseChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setPhrase(e.target.value);
|
||||
}, []);
|
||||
|
||||
const triggerPhrases = useMemo(() => {
|
||||
return modelConfig?.trigger_phrases || [];
|
||||
}, [modelConfig?.trigger_phrases]);
|
||||
|
||||
const errors = useMemo(() => {
|
||||
const errors = [];
|
||||
|
||||
if (phrase.length && triggerPhrases.includes(phrase)) {
|
||||
errors.push('Phrase is already in list');
|
||||
}
|
||||
|
||||
return errors;
|
||||
}, [phrase, triggerPhrases]);
|
||||
|
||||
const addTriggerPhrase = useCallback(async () => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!phrase.length || triggerPhrases.includes(phrase)) {
|
||||
return;
|
||||
}
|
||||
|
||||
await updateModel({
|
||||
key: selectedModelKey,
|
||||
body: { trigger_phrases: [...triggerPhrases, phrase] },
|
||||
}).unwrap();
|
||||
setPhrase('');
|
||||
}, [updateModel, selectedModelKey, phrase, triggerPhrases]);
|
||||
|
||||
const removeTriggerPhrase = useCallback(
|
||||
async (phraseToRemove: string) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
|
||||
|
||||
await updateModel({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
||||
},
|
||||
[updateModel, selectedModelKey, triggerPhrases]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full" gap="5">
|
||||
<ModelListHeader title={t('modelManager.triggerPhrases')} />
|
||||
<form>
|
||||
<FormControl w="full" isInvalid={Boolean(errors.length)}>
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex gap="3" alignItems="center" w="full">
|
||||
<Input value={phrase} onChange={handlePhraseChange} placeholder={t('modelManager.typePhraseHere')} />
|
||||
<Button
|
||||
type="submit"
|
||||
onClick={addTriggerPhrase}
|
||||
isDisabled={Boolean(errors.length)}
|
||||
isLoading={isLoading}
|
||||
>
|
||||
{t('common.add')}
|
||||
</Button>
|
||||
</Flex>
|
||||
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</form>
|
||||
|
||||
<Flex gap="4" flexWrap="wrap" mt="3" mb="3">
|
||||
{triggerPhrases.map((phrase, index) => (
|
||||
<Tag size="md" key={index}>
|
||||
<TagLabel>{phrase}</TagLabel>
|
||||
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
|
||||
</Tag>
|
||||
))}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -13,7 +13,7 @@ import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { useConvertModelMutation } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/types';
|
||||
|
||||
interface ModelConvertProps {
|
||||
@@ -24,7 +24,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
|
||||
const { model } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
|
||||
const [convertModel, { isLoading }] = useConvertModelMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const modelConvertHandler = useCallback(() => {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import {
|
||||
Button,
|
||||
Checkbox,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
@@ -19,66 +20,27 @@ import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||
import BooleanSelect from './Fields/BooleanSelect';
|
||||
import ModelFormatSelect from './Fields/ModelFormatSelect';
|
||||
import ModelTypeSelect from './Fields/ModelTypeSelect';
|
||||
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||
import RepoVariantSelect from './Fields/RepoVariantSelect';
|
||||
|
||||
export const ModelEdit = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
|
||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
// const modelData = useMemo(() => {
|
||||
// if (!data) {
|
||||
// return null;
|
||||
// }
|
||||
// const modelFormat = data.format;
|
||||
// const modelType = data.type;
|
||||
|
||||
// if (modelType === 'main') {
|
||||
// if (modelFormat === 'diffusers') {
|
||||
// return data as DiffusersModelConfig;
|
||||
// } else if (modelFormat === 'checkpoint') {
|
||||
// return data as CheckpointModelConfig;
|
||||
// }
|
||||
// }
|
||||
|
||||
// switch (modelType) {
|
||||
// case 'lora':
|
||||
// return data as LoRAModelConfig;
|
||||
// case 'embedding':
|
||||
// return data as TextualInversionModelConfig;
|
||||
// case 't2i_adapter':
|
||||
// return data as T2IAdapterModelConfig;
|
||||
// case 'ip_adapter':
|
||||
// return data as IPAdapterModelConfig;
|
||||
// case 'controlnet':
|
||||
// return data as ControlNetModelConfig;
|
||||
// case 'vae':
|
||||
// return data as VAEModelConfig;
|
||||
// default:
|
||||
// return null;
|
||||
// }
|
||||
// }, [data]);
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
reset,
|
||||
watch,
|
||||
} = useForm<UpdateModelArg['body']>({
|
||||
defaultValues: {
|
||||
...data,
|
||||
@@ -86,10 +48,7 @@ export const ModelEdit = () => {
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const watchedModelType = watch('type');
|
||||
const watchedModelFormat = watch('format');
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
||||
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
||||
(values) => {
|
||||
if (!data?.key) {
|
||||
return;
|
||||
@@ -143,33 +102,31 @@ export const ModelEdit = () => {
|
||||
return (
|
||||
<Flex flexDir="column" h="full">
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
|
||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
|
||||
<FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel>
|
||||
<Input
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
||||
})}
|
||||
size="lg"
|
||||
/>
|
||||
|
||||
<Flex gap={2}>
|
||||
<Button size="sm" onClick={handleClickCancel}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
colorScheme="invokeYellow"
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
isLoading={isSubmitting}
|
||||
isDisabled={Boolean(Object.keys(errors).length)}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<Button size="sm" onClick={handleClickCancel}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
colorScheme="invokeYellow"
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
isLoading={isSubmitting}
|
||||
isDisabled={Boolean(Object.keys(errors).length)}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={3} mt="4">
|
||||
<Flex>
|
||||
@@ -184,76 +141,22 @@ export const ModelEdit = () => {
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={control} name="base" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
||||
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
||||
<BaseModelSelect control={control} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('common.format')}</FormLabel>
|
||||
<ModelFormatSelect control={control} name="format" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.path')}</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
</Flex>
|
||||
{watchedModelType === 'main' && (
|
||||
<>
|
||||
<Flex gap={4}>
|
||||
{watchedModelFormat === 'diffusers' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
|
||||
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
||||
</FormControl>
|
||||
)}
|
||||
{watchedModelFormat === 'checkpoint' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
<Input {...register('config')} />
|
||||
</FormControl>
|
||||
)}
|
||||
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{watchedModelType === 'ip_adapter' && (
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
|
||||
<Input {...register('image_encoder_model_id')} />
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={control} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect control={control} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<Checkbox {...register('upcast_attention')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
)}
|
||||
|
||||
@@ -91,26 +91,19 @@ export const ModelView = () => {
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
|
||||
</Flex>
|
||||
{modelData.type === 'main' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
{modelData.format === 'diffusers' && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
|
||||
)}
|
||||
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
|
||||
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
|
||||
</Flex>
|
||||
</>
|
||||
<Flex gap={2}>
|
||||
{modelData.format === 'diffusers' && modelData.repo_variant && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
)}
|
||||
{modelData.type === 'ip_adapter' && (
|
||||
<Flex gap={2}>
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -19,19 +19,14 @@ export const ParamNegativePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||
prompt,
|
||||
textareaRef,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="negativePrompt"
|
||||
@@ -45,10 +40,10 @@ export const ParamNegativePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { setPositivePrompt } from 'features/parameters/store/generationSlice';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { SDXLConcatButton } from 'features/sdxl/components/SDXLPrompts/SDXLConcatButton';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import type { HotkeyCallback } from 'react-hotkeys-hook';
|
||||
@@ -25,7 +25,7 @@ export const ParamPositivePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
@@ -42,12 +42,7 @@ export const ParamPositivePrompt = memo(() => {
|
||||
useHotkeys('alt+a', focus, []);
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@@ -61,12 +56,12 @@ export const ParamPositivePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
{baseModel === 'sdxl' && <SDXLConcatButton />}
|
||||
<ShowDynamicPromptsPreviewButton />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import type {
|
||||
ParameterScheduler,
|
||||
ParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { clamp } from 'lodash-es';
|
||||
@@ -210,26 +209,6 @@ export const generationSlice = createSlice({
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(configChanged, (state, action) => {
|
||||
const defaultModel = action.payload.sd?.defaultModel;
|
||||
|
||||
if (defaultModel && !state.model) {
|
||||
const [base_model, model_type, model_name] = defaultModel.split('/');
|
||||
|
||||
const result = zParameterModel.safeParse({
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
});
|
||||
|
||||
if (result.success) {
|
||||
state.model = result.data;
|
||||
|
||||
const optimalDimension = getOptimalDimension(result.data);
|
||||
|
||||
state.width = optimalDimension;
|
||||
state.height = optimalDimension;
|
||||
}
|
||||
}
|
||||
if (action.payload.sd?.scheduler) {
|
||||
state.scheduler = action.payload.sd.scheduler;
|
||||
}
|
||||
|
||||
@@ -8,15 +8,15 @@ type Props = {
|
||||
onOpen: () => void;
|
||||
};
|
||||
|
||||
export const AddEmbeddingButton = memo((props: Props) => {
|
||||
export const AddPromptTriggerButton = memo((props: Props) => {
|
||||
const { onOpen, isOpen } = props;
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Tooltip label={t('embedding.addEmbedding')}>
|
||||
<Tooltip label={t('prompt.addPromptTrigger')}>
|
||||
<IconButton
|
||||
variant="promptOverlay"
|
||||
isDisabled={isOpen}
|
||||
aria-label={t('embedding.addEmbedding')}
|
||||
aria-label={t('prompt.addPromptTrigger')}
|
||||
icon={<PiCodeBold />}
|
||||
onClick={onOpen}
|
||||
/>
|
||||
@@ -24,4 +24,4 @@ export const AddEmbeddingButton = memo((props: Props) => {
|
||||
);
|
||||
});
|
||||
|
||||
AddEmbeddingButton.displayName = 'AddEmbeddingButton';
|
||||
AddPromptTriggerButton.displayName = 'AddPromptTriggerButton';
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
|
||||
import { EmbeddingSelect } from 'features/embedding/EmbeddingSelect';
|
||||
import type { EmbeddingPopoverProps } from 'features/embedding/types';
|
||||
import { PromptTriggerSelect } from 'features/prompt/PromptTriggerSelect';
|
||||
import type { PromptPopoverProps } from 'features/prompt/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
export const PromptPopover = memo((props: PromptPopoverProps) => {
|
||||
const { onSelect, isOpen, onClose, width, children } = props;
|
||||
|
||||
return (
|
||||
@@ -14,7 +14,7 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
openDelay={0}
|
||||
closeDelay={0}
|
||||
closeOnBlur={true}
|
||||
returnFocusOnClose={true}
|
||||
returnFocusOnClose={false}
|
||||
isLazy
|
||||
>
|
||||
<PopoverAnchor>{children}</PopoverAnchor>
|
||||
@@ -27,11 +27,11 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
borderStyle="solid"
|
||||
>
|
||||
<PopoverBody p={0} width={`calc(${width}px - 0.25rem)`}>
|
||||
<EmbeddingSelect onClose={onClose} onSelect={onSelect} />
|
||||
<PromptTriggerSelect onClose={onClose} onSelect={onSelect} />
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
});
|
||||
|
||||
EmbeddingPopover.displayName = 'EmbeddingPopover';
|
||||
PromptPopover.displayName = 'PromptPopover';
|
||||
@@ -0,0 +1,21 @@
|
||||
import type { Meta, StoryObj } from '@storybook/react';
|
||||
|
||||
import { PromptTriggerSelect } from './PromptTriggerSelect';
|
||||
import type { PromptTriggerSelectProps } from './types';
|
||||
|
||||
const meta: Meta<typeof PromptTriggerSelect> = {
|
||||
title: 'Feature/Prompt/PromptTriggerSelect',
|
||||
tags: ['autodocs'],
|
||||
component: PromptTriggerSelect,
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof PromptTriggerSelect>;
|
||||
|
||||
const Component = (props: PromptTriggerSelectProps) => {
|
||||
return <PromptTriggerSelect {...props}>Invoke</PromptTriggerSelect>;
|
||||
};
|
||||
|
||||
export const Default: Story = {
|
||||
render: Component,
|
||||
};
|
||||
@@ -0,0 +1,104 @@
|
||||
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase } from 'chakra-react-select';
|
||||
import { selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
import type { PromptTriggerSelectProps } from 'features/prompt/types';
|
||||
import { t } from 'i18next';
|
||||
import { flatten, map } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
loraModelsAdapterSelectors,
|
||||
textualInversionModelsAdapterSelectors,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||
|
||||
const selectLoRAs = createMemoizedSelector(selectLoraSlice, (loras) => loras.loras);
|
||||
|
||||
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const addedLoRAs = useAppSelector(selectLoRAs);
|
||||
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
|
||||
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
|
||||
|
||||
const _onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!v) {
|
||||
onSelect('');
|
||||
return;
|
||||
}
|
||||
|
||||
onSelect(v.value);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const options = useMemo(() => {
|
||||
const _options: GroupBase<ComboboxOption>[] = [];
|
||||
|
||||
if (tiModels) {
|
||||
const embeddingOptions = textualInversionModelsAdapterSelectors
|
||||
.selectAll(tiModels)
|
||||
.filter((ti) => ti.base === currentBaseModel)
|
||||
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
|
||||
|
||||
if (embeddingOptions.length > 0) {
|
||||
_options.push({
|
||||
label: t('prompt.compatibleEmbeddings'),
|
||||
options: embeddingOptions,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (loraModels) {
|
||||
const triggerPhraseOptions = loraModelsAdapterSelectors
|
||||
.selectAll(loraModels)
|
||||
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
|
||||
.map((lora) => {
|
||||
if (lora.trigger_phrases) {
|
||||
return lora.trigger_phrases.map((triggerPhrase) => ({ label: triggerPhrase, value: triggerPhrase }));
|
||||
}
|
||||
return [];
|
||||
})
|
||||
.flatMap((x) => x);
|
||||
|
||||
if (triggerPhraseOptions.length > 0) {
|
||||
_options.push({
|
||||
label: t('modelManager.triggerPhrases'),
|
||||
options: flatten(triggerPhraseOptions),
|
||||
});
|
||||
}
|
||||
}
|
||||
return _options;
|
||||
}, [tiModels, loraModels, t, currentBaseModel, addedLoRAs]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<Combobox
|
||||
placeholder={isLoadingLoRAs || isLoadingTIs ? t('common.loading') : t('prompt.addPromptTrigger')}
|
||||
defaultMenuIsOpen
|
||||
autoFocus
|
||||
value={null}
|
||||
options={options}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
onChange={_onChange}
|
||||
onMenuClose={onClose}
|
||||
data-testid="add-prompt-trigger"
|
||||
sx={selectStyles}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
PromptTriggerSelect.displayName = 'PromptTriggerSelect';
|
||||
|
||||
const selectStyles: ChakraProps['sx'] = {
|
||||
w: 'full',
|
||||
};
|
||||
@@ -1,12 +1,12 @@
|
||||
import type { PropsWithChildren } from 'react';
|
||||
|
||||
export type EmbeddingSelectProps = {
|
||||
export type PromptTriggerSelectProps = {
|
||||
onSelect: (v: string) => void;
|
||||
onClose: () => void;
|
||||
};
|
||||
|
||||
export type EmbeddingPopoverProps = PropsWithChildren &
|
||||
EmbeddingSelectProps & {
|
||||
export type PromptPopoverProps = PropsWithChildren &
|
||||
PromptTriggerSelectProps & {
|
||||
isOpen: boolean;
|
||||
width?: number | string;
|
||||
};
|
||||
@@ -4,13 +4,13 @@ import type { ChangeEventHandler, KeyboardEventHandler, RefObject } from 'react'
|
||||
import { useCallback } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
|
||||
type UseInsertEmbeddingArg = {
|
||||
type UseInsertTriggerArg = {
|
||||
prompt: string;
|
||||
textareaRef: RefObject<HTMLTextAreaElement>;
|
||||
onChange: (v: string) => void;
|
||||
};
|
||||
|
||||
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertEmbeddingArg) => {
|
||||
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertTriggerArg) => {
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
|
||||
const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback(
|
||||
@@ -20,13 +20,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
[_onChange]
|
||||
);
|
||||
|
||||
const insertEmbedding = useCallback(
|
||||
const insertTrigger = useCallback(
|
||||
(v: string) => {
|
||||
if (!textareaRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
// this is where we insert the trigger
|
||||
const caret = textareaRef.current.selectionStart;
|
||||
|
||||
if (isNil(caret)) {
|
||||
@@ -35,13 +35,9 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
|
||||
let newPrompt = prompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
newPrompt += `${v}`;
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
// we insert the cursor after the end of trigger
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += prompt.slice(caret);
|
||||
@@ -51,7 +47,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
_onChange(newPrompt);
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger
|
||||
// set the cursor position to just after the trigger
|
||||
textareaRef.current.selectionStart = finalCaretPos;
|
||||
textareaRef.current.selectionEnd = finalCaretPos;
|
||||
},
|
||||
@@ -62,17 +58,17 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
textareaRef.current?.focus();
|
||||
}, [textareaRef]);
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
const handleClosePopover = useCallback(() => {
|
||||
onClose();
|
||||
onFocus();
|
||||
}, [onFocus, onClose]);
|
||||
|
||||
const onSelectEmbedding = useCallback(
|
||||
const onSelect = useCallback(
|
||||
(v: string) => {
|
||||
insertEmbedding(v);
|
||||
handleClose();
|
||||
insertTrigger(v);
|
||||
handleClosePopover();
|
||||
},
|
||||
[handleClose, insertEmbedding]
|
||||
[handleClosePopover, insertTrigger]
|
||||
);
|
||||
|
||||
const onKeyDown: KeyboardEventHandler<HTMLTextAreaElement> = useCallback(
|
||||
@@ -90,7 +86,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
isOpen,
|
||||
onClose,
|
||||
onOpen,
|
||||
onSelectEmbedding,
|
||||
onSelect,
|
||||
onKeyDown,
|
||||
onFocus,
|
||||
};
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { setNegativeStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
@@ -20,7 +20,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
@@ -29,12 +29,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
useHotkeys('alt+a', onFocus, []);
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@@ -48,10 +43,10 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { setPositiveStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -19,19 +19,14 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@@ -45,10 +40,10 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import type { JSONObject } from 'common/types';
|
||||
import queryString from 'query-string';
|
||||
import type { operations, paths } from 'services/api/schema';
|
||||
import type {
|
||||
@@ -24,49 +23,33 @@ export type UpdateModelArg = {
|
||||
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
type UpdateModelMetadataArg = {
|
||||
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
|
||||
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
type UpdateModelMetadataResponse =
|
||||
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelMetadataResponse =
|
||||
paths['/api/v2/models/i/{key}/metadata']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
||||
|
||||
type DeleteMainModelArg = {
|
||||
type DeleteModelArg = {
|
||||
key: string;
|
||||
};
|
||||
|
||||
type DeleteMainModelResponse = void;
|
||||
type DeleteModelResponse = void;
|
||||
|
||||
type ConvertMainModelResponse =
|
||||
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type InstallModelArg = {
|
||||
source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
|
||||
access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token'];
|
||||
// TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend...
|
||||
config?: JSONObject;
|
||||
// config: NonNullable<paths['/api/v2/models/install']['post']['requestBody']>['content']['application/json'];
|
||||
};
|
||||
|
||||
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
type ListImportModelsResponse =
|
||||
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
|
||||
type ListModelInstallsResponse =
|
||||
paths['/api/v2/models/install']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type DeleteImportModelsResponse =
|
||||
paths['/api/v2/models/import/{id}']['delete']['responses']['201']['content']['application/json'];
|
||||
type CancelModelInstallResponse =
|
||||
paths['/api/v2/models/install/{id}']['delete']['responses']['201']['content']['application/json'];
|
||||
|
||||
type PruneModelImportsResponse =
|
||||
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
||||
type PruneCompletedModelInstallsResponse =
|
||||
paths['/api/v2/models/install']['delete']['responses']['200']['content']['application/json'];
|
||||
|
||||
export type ScanFolderResponse =
|
||||
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
||||
@@ -83,6 +66,7 @@ const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
@@ -102,6 +86,10 @@ const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelC
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
@@ -146,31 +134,7 @@ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||
query: (base_models) => {
|
||||
const params: ListModelsArg = {
|
||||
model_type: 'main',
|
||||
base_models,
|
||||
};
|
||||
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
return buildModelsUrl(`?${query}`);
|
||||
},
|
||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertModelConfigs(data, dispatch);
|
||||
});
|
||||
},
|
||||
}),
|
||||
getModelMetadata: build.query<GetModelMetadataResponse, string>({
|
||||
query: (key) => {
|
||||
return buildModelsUrl(`i/${key}/metadata`);
|
||||
},
|
||||
providesTags: ['Model'],
|
||||
}),
|
||||
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
|
||||
updateModel: build.mutation<UpdateModelResponse, UpdateModelArg>({
|
||||
query: ({ key, body }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}`),
|
||||
@@ -180,28 +144,17 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
|
||||
query: ({ key, body }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}/metadata`),
|
||||
method: 'PATCH',
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||
query: ({ source, config, access_token }) => {
|
||||
query: ({ source }) => {
|
||||
return {
|
||||
url: buildModelsUrl('install'),
|
||||
params: { source, access_token },
|
||||
params: { source },
|
||||
method: 'POST',
|
||||
body: config,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model', 'ModelImports'],
|
||||
invalidatesTags: ['Model', 'ModelInstalls'],
|
||||
}),
|
||||
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
||||
deleteModels: build.mutation<DeleteModelResponse, DeleteModelArg>({
|
||||
query: ({ key }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}`),
|
||||
@@ -210,7 +163,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
convertMainModels: build.mutation<ConvertMainModelResponse, string>({
|
||||
convertModel: build.mutation<ConvertMainModelResponse, string>({
|
||||
query: (key) => {
|
||||
return {
|
||||
url: buildModelsUrl(`convert/${key}`),
|
||||
@@ -253,6 +206,57 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
scanFolder: build.query<ScanFolderResponse, ScanFolderArg>({
|
||||
query: (arg) => {
|
||||
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
|
||||
return {
|
||||
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
|
||||
};
|
||||
},
|
||||
}),
|
||||
listModelInstalls: build.query<ListModelInstallsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: buildModelsUrl('install'),
|
||||
};
|
||||
},
|
||||
providesTags: ['ModelInstalls'],
|
||||
}),
|
||||
cancelModelInstall: build.mutation<CancelModelInstallResponse, number>({
|
||||
query: (id) => {
|
||||
return {
|
||||
url: buildModelsUrl(`install/${id}`),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['ModelInstalls'],
|
||||
}),
|
||||
pruneCompletedModelInstalls: build.mutation<PruneCompletedModelInstallsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: buildModelsUrl('install'),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['ModelInstalls'],
|
||||
}),
|
||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||
query: (base_models) => {
|
||||
const params: ListModelsArg = {
|
||||
model_type: 'main',
|
||||
base_models,
|
||||
};
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
return buildModelsUrl(`?${query}`);
|
||||
},
|
||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertModelConfigs(data, dispatch);
|
||||
});
|
||||
},
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
||||
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
||||
@@ -313,40 +317,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
});
|
||||
},
|
||||
}),
|
||||
scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
|
||||
query: (arg) => {
|
||||
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
|
||||
return {
|
||||
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
|
||||
};
|
||||
},
|
||||
}),
|
||||
getModelImports: build.query<ListImportModelsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: buildModelsUrl(`import`),
|
||||
};
|
||||
},
|
||||
providesTags: ['ModelImports'],
|
||||
}),
|
||||
deleteModelImport: build.mutation<DeleteImportModelsResponse, number>({
|
||||
query: (id) => {
|
||||
return {
|
||||
url: buildModelsUrl(`import/${id}`),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['ModelImports'],
|
||||
}),
|
||||
pruneModelImports: build.mutation<PruneModelImportsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: buildModelsUrl('import'),
|
||||
method: 'PATCH',
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['ModelImports'],
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -360,16 +330,14 @@ export const {
|
||||
useGetTextualInversionModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
useDeleteModelsMutation,
|
||||
useUpdateModelsMutation,
|
||||
useUpdateModelMutation,
|
||||
useInstallModelMutation,
|
||||
useConvertMainModelsMutation,
|
||||
useConvertModelMutation,
|
||||
useSyncModelsMutation,
|
||||
useLazyScanModelsQuery,
|
||||
useGetModelImportsQuery,
|
||||
useGetModelMetadataQuery,
|
||||
useDeleteModelImportMutation,
|
||||
usePruneModelImportsMutation,
|
||||
useUpdateModelMetadataMutation,
|
||||
useLazyScanFolderQuery,
|
||||
useListModelInstallsQuery,
|
||||
useCancelModelInstallMutation,
|
||||
usePruneCompletedModelInstallsMutation,
|
||||
} = modelsApi;
|
||||
|
||||
const upsertModelConfigs = (
|
||||
|
||||
@@ -28,7 +28,7 @@ export const tagTypes = [
|
||||
'InvocationCacheStatus',
|
||||
'Model',
|
||||
'ModelConfig',
|
||||
'ModelImports',
|
||||
'ModelInstalls',
|
||||
'T2IAdapterModel',
|
||||
'MainModel',
|
||||
'VaeModel',
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -43,14 +43,13 @@ export type ControlField = S['ControlField'];
|
||||
// Model Configs
|
||||
|
||||
// TODO(MM2): Can we make key required in the pydantic model?
|
||||
export type LoRAModelConfig = S['LoRAConfig'];
|
||||
export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
|
||||
// TODO(MM2): Can we rename this from Vae -> VAE
|
||||
export type VAEModelConfig = S['VaeCheckpointConfig'] | S['VaeDiffusersConfig'];
|
||||
export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
|
||||
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
||||
export type IPAdapterModelConfig = S['IPAdapterConfig'];
|
||||
// TODO(MM2): Can we rename this to T2IAdapterConfig
|
||||
export type T2IAdapterModelConfig = S['T2IConfig'];
|
||||
export type TextualInversionModelConfig = S['TextualInversionConfig'];
|
||||
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||
export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||
export type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user