Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-21 07:17:57 -05:00)

Compare commits: improve-co ... v4.2.5 (247 commits)
.gitignore (vendored, 1 change)

@@ -188,4 +188,3 @@ installer/install.sh
 installer/update.bat
 installer/update.sh
 installer/InvokeAI-Installer/
-.aider*
Makefile (4 changes)

@@ -18,6 +18,7 @@ help:
 	@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
 	@echo "installer-zip Build the installer .zip file for the current version"
 	@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
+	@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"

 # Runs ruff, fixing any safely-fixable errors and formatting
 ruff:
@@ -70,3 +71,6 @@ installer-zip:
 tag-release:
 	cd installer && ./tag_release.sh
+
+# Generate the OpenAPI Schema for the app
+openapi:
+	python scripts/generate_openapi_schema.py
@@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
 specify the source and destination of the download, and keep track of
 the progress of the download.
 
-The only job type currently implemented is `DownloadJob`, a pydantic object with the
+Two job types are defined. `DownloadJob` and
+`MultiFileDownloadJob`. The former is a pydantic object with the
 following fields:
 
 | **Field** | **Type** | **Default** | **Description** |
@@ -138,7 +139,7 @@ following fields:
 | `dest` | Path | | Where to download to |
 | `access_token` | str | | [optional] string containing authentication token for access |
 | `on_start` | Callable | | [optional] callback when the download starts |
 | `on_progress` | Callable | | [optional] callback called at intervals during download progress |
 | `on_complete` | Callable | | [optional] callback called after successful download completion |
 | `on_error` | Callable | | [optional] callback called after an error occurs |
 | `id` | int | auto assigned | Job ID, an integer >= 0 |
@@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
 `error_type` field of "DownloadJobCancelledException". In addition,
 the job's `cancelled` property will be set to True.
 
+The `MultiFileDownloadJob` is used for diffusers model downloads,
+which contain multiple files and directories under a common root:
+
+| **Field** | **Type** | **Default** | **Description** |
+|----------------|-----------------|---------------|-----------------|
+| _Fields passed in at job creation time_ |
+| `download_parts` | Set[DownloadJob] | | Component download jobs |
+| `dest` | Path | | Where to download to |
+| `on_start` | Callable | | [optional] callback when the download starts |
+| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
+| `on_complete` | Callable | | [optional] callback called after successful download completion |
+| `on_error` | Callable | | [optional] callback called after an error occurs |
+| `id` | int | auto assigned | Job ID, an integer >= 0 |
+| _Fields updated over the course of the download task_ |
+| `status` | DownloadJobStatus | | Status code |
+| `download_path` | Path | | Path to the root of the downloaded files |
+| `bytes` | int | 0 | Bytes downloaded so far |
+| `total_bytes` | int | 0 | Total size of the file at the remote site |
+| `error_type` | str | | String version of the exception that caused an error during download |
+| `error` | str | | String version of the traceback associated with an error |
+| `cancelled` | bool | False | Set to true if the job was cancelled by the caller |
+
+Note that the MultiFileDownloadJob does not support the `priority`,
+`job_started`, `job_ended` or `content_type` attributes. You can get
+these from the individual download jobs in `download_parts`.
+
+
 ### Callbacks
 
 Download jobs can be associated with a series of callbacks, each with
@@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with
 running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
 with `join()`.
 
-#### job = queue.download(source, dest, priority, access_token)
+#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
 
 Create a new download job and put it on the queue, returning the
 DownloadJob object.
 
+#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
+
+This is similar to download(), but instead of taking a single source,
+it accepts a `parts` argument consisting of a list of
+`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
+where the URL is the location of the remote file, and the Path is the
+destination.
+
+`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
+consists of a url/path pair. Note that the path *must* be relative.
+
+The method returns a `MultiFileDownloadJob`.
+
+```
+from invokeai.backend.model_manager.metadata import RemoteModelFile
+remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
+                                path='my_model/textencoder/pytorch_model.safetensors'
+                                )
+remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
+                                path='my_model/vae/diffusers_model.safetensors'
+                                )
+job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
+                               dest='/tmp/downloads',
+                               on_progress=TqdmProgress().update)
+queue.wait_for_job(job)
+print(f"The files were downloaded to {job.download_path}")
+```
+
 #### jobs = queue.list_jobs()
 
 Return a list of all active and inactive `DownloadJob`s.
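The `download_parts` field of the `MultiFileDownloadJob` documented above holds the per-file `DownloadJob`s, which is where attributes such as `priority` and `content_type` live. A minimal illustrative sketch, not part of the changeset, assuming a `multifile_job` obtained from `queue.multifile_download()` as in the example above:

```python
# Illustrative only: inspect the component jobs of a MultiFileDownloadJob.
# Field names follow the tables in the documentation above.
for part in multifile_job.download_parts:
    # Each part is an ordinary DownloadJob with its own status and byte counts.
    print(f"{part.dest}: {part.bytes}/{part.total_bytes} bytes, status={part.status}")

# Aggregate progress is tracked on the parent job.
print(f"Total: {multifile_job.bytes}/{multifile_job.total_bytes} bytes under {multifile_job.download_path}")
```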
@@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
 following initialization pattern:
 
 ```
-from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.config import get_config
 from invokeai.app.services.model_records import ModelRecordServiceSQL
 from invokeai.app.services.model_install import ModelInstallService
 from invokeai.app.services.download import DownloadQueueService
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 from invokeai.backend.util.logging import InvokeAILogger
 
-config = InvokeAIAppConfig.get_config()
-config.parse_args()
+config = get_config()
 
 logger = InvokeAILogger.get_logger(config=config)
-db = SqliteDatabase(config, logger)
+db = SqliteDatabase(config.db_path, logger)
 record_store = ModelRecordServiceSQL(db)
 queue = DownloadQueueService()
 queue.start()
 
 installer = ModelInstallService(app_config=config,
                                 record_store=record_store,
                                 download_queue=queue
                                 )
 installer.start()
 ```
 
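As a hedged follow-on to the initialization pattern above, not part of the diff: once the installer is started, `heuristic_import()`, the same method the new HuggingFace install route further down relies on, can be used to queue an install. The source string and the shutdown calls below are illustrative assumptions about the service API.

```python
# Illustrative only: queue an install on the service configured above.
install_job = installer.heuristic_import("stabilityai/sd-turbo")  # hypothetical source

# ... wait for or poll the returned job as appropriate for your code ...

# Assumed shutdown sequence for the services started above.
installer.stop()
queue.stop()
```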
@@ -1367,12 +1366,20 @@ the in-memory loaded model:
 | `model` | AnyModel | The instantiated model (details below) |
 | `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
 
-Because the loader can return multiple model types, it is typed to
-return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
-`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
-`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
-models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
-models. The others are obvious.
+### get_model_by_key(key, [submodel]) -> LoadedModel
+
+The `get_model_by_key()` method will retrieve the model using its
+unique database key. For example:
+
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+
+`get_model_by_key()` may raise any of the following exceptions:
+
+* `UnknownModelException` -- key not in database
+* `ModelNotFoundException` -- key in database but model not found at path
+* `NotImplementedException` -- the loader doesn't know how to load this type of model
+
+### Using the Loaded Model in Inference
+
 `LoadedModel` acts as a context manager. The context loads the model
 into the execution device (e.g. VRAM on CUDA systems), locks the model
@@ -1380,17 +1387,33 @@ in the execution device for the duration of the context, and returns
 the model. Use it like this:
 
 ```
-model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
-with model_info as vae:
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model as vae:
     image = vae.decode(latents)[0]
 ```
 
-`get_model_by_key()` may raise any of the following exceptions:
+The object returned by the LoadedModel context manager is an
+`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
+`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
+`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
+models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
+models. The others are obvious.
+
+In addition, you may call `LoadedModel.model_on_device()`, a context
+manager that returns a tuple of the model's state dict in CPU and the
+model itself in VRAM. It is used to optimize the LoRA patching and
+unpatching process:
+
+```
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model.model_on_device() as (state_dict, vae):
+    image = vae.decode(latents)[0]
+```
+
+Since not all models have state dicts, the `state_dict` return value
+can be None.
+
 
-* `UnknownModelException` -- key not in database
-* `ModelNotFoundException` -- key in database but model not found at path
-* `NotImplementedException` -- the loader doesn't know how to load this type of model
 
 ### Emitting model loading events
 
 When the `context` argument is passed to `load_model_*()`, it will
@@ -1578,3 +1601,59 @@ This method takes a model key, looks it up using the
 `ModelRecordServiceBase` object in `mm.store`, and passes the returned
 model configuration to `load_model_by_config()`. It may raise a
 `NotImplementedException`.
+
+## Invocation Context Model Manager API
+
+Within invocations, the following methods are available from the
+`InvocationContext` object:
+
+### context.download_and_cache_model(source) -> Path
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, and then returns a Path to the local model. The source can
+be a direct download URL or a HuggingFace repo_id.
+
+In the case of a HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
+
+### context.load_local_model(model_path, [loader]) -> LoadedModel
+
+This method loads a local model from the indicated path, returning a
+`LoadedModel`. The optional loader is a Callable that accepts a Path
+to the object and returns an `AnyModel` object. If no loader is
+provided, then the method will use `torch.load()` for a .ckpt or .bin
+checkpoint file, `safetensors.torch.load_file()` for a safetensors
+checkpoint file, or `cls.from_pretrained()` for a directory that looks
+like a diffusers directory.
+
+### context.load_remote_model(source, [loader]) -> LoadedModel
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, loads it, and returns a `LoadedModel`. The source can be a
+direct download URL or a HuggingFace repo_id.
+
+In the case of a HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
+
+
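To make the new `InvocationContext` methods above concrete, here is a minimal sketch of how they might be combined inside an invocation. It is illustrative only: the URL and the custom loader are assumptions, and the snippet needs a live `InvocationContext` supplied by the application at runtime.

```python
from pathlib import Path

import torch


def run_with_context(context) -> None:
    """Sketch: `context` is an InvocationContext as documented above."""
    # Download a remote checkpoint (or reuse the cached copy) and get its local path.
    local_path: Path = context.download_and_cache_model(
        "https://example.com/models/some_model.safetensors"  # hypothetical URL
    )

    # Optional custom loader; without it the defaults described above are used.
    def my_loader(path: Path) -> torch.nn.Module:
        return torch.load(path, map_location="cpu")  # assumes a plain torch checkpoint

    # LoadedModel is a context manager: the model is locked in the execution device inside the block.
    with context.load_local_model(local_path, my_loader) as model:
        ...  # run inference with `model`
```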
@@ -154,6 +154,18 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
 
 Check the [configuration docs] for more detail about the settings and how to specify them.
 
+## `ModuleNotFoundError: No module named 'controlnet_aux'`
+
+`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
+
+If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
+
+- Run the Invoke launcher
+- Choose the developer console option
+- Run this command: `pip cache remove controlnet_aux`
+- Close the terminal window
+- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location
+
 ## Out of Memory Issues
 
 The models are large, VRAM is expensive, and you may find yourself
@@ -93,7 +93,7 @@ class ApiDependencies:
         conditioning = ObjectSerializerForwardCache(
             ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
         )
-        download_queue_service = DownloadQueueService(event_bus=events)
+        download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
         model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
         model_manager = ModelManagerService.build_model_manager(
             app_config=configuration,
@@ -9,7 +9,7 @@ from copy import deepcopy
 from typing import Any, Dict, List, Optional, Type
 
 from fastapi import Body, Path, Query, Response, UploadFile
-from fastapi.responses import FileResponse
+from fastapi.responses import FileResponse, HTMLResponse
 from fastapi.routing import APIRouter
 from PIL import Image
 from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
@@ -502,6 +502,133 @@ async def install_model(
     return result
 
 
+@model_manager_router.get(
+    "/install/huggingface",
+    operation_id="install_hugging_face_model",
+    responses={
+        201: {"description": "The model is being installed"},
+        400: {"description": "Bad request"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
+    },
+    status_code=201,
+    response_class=HTMLResponse,
+)
+async def install_hugging_face_model(
+    source: str = Query(description="HuggingFace repo_id to install"),
+) -> HTMLResponse:
+    """Install a Hugging Face model using a string identifier."""
+
+    def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str:
+        if message:
+            message = f"<p>{message}</p>"
+        title_class = "error" if is_error else "success"
+        return f"""
+        <html>
+
+        <head>
+            <title>{title}</title>
+            <style>
+                body {{
+                    text-align: center;
+                    background-color: hsl(220 12% 10% / 1);
+                    font-family: Helvetica, sans-serif;
+                    color: hsl(220 12% 86% / 1);
+                }}
+
+                .repo-id {{
+                    color: hsl(220 12% 68% / 1);
+                }}
+
+                .error {{
+                    color: hsl(0 42% 68% / 1)
+                }}
+
+                .message-box {{
+                    display: inline-block;
+                    border-radius: 5px;
+                    background-color: hsl(220 12% 20% / 1);
+                    padding-inline-end: 30px;
+                    padding: 20px;
+                    padding-inline-start: 30px;
+                    padding-inline-end: 30px;
+                }}
+
+                .container {{
+                    display: flex;
+                    width: 100%;
+                    height: 100%;
+                    align-items: center;
+                    justify-content: center;
+                }}
+
+                a {{
+                    color: inherit
+                }}
+
+                a:visited {{
+                    color: inherit
+                }}
+
+                a:active {{
+                    color: inherit
+                }}
+            </style>
+        </head>
+
+        <body style="background-color: hsl(220 12% 10% / 1);">
+            <div class="container">
+                <div class="message-box">
+                    <h2 class="{title_class}">{heading}</h2>
+                    {message}
+                    <p class="repo-id">Repo ID: {repo_id}</p>
+                </div>
+            </div>
+        </body>
+
+        </html>
+        """
+
+    try:
+        metadata = HuggingFaceMetadataFetch().from_id(source)
+        assert isinstance(metadata, ModelMetadataWithFiles)
+    except UnknownMetadataException:
+        title = "Unable to Install Model"
+        heading = "No HuggingFace repository found with that repo ID."
+        message = "Ensure the repo ID is correct and try again."
+        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)
+
+    logger = ApiDependencies.invoker.services.logger
+
+    try:
+        installer = ApiDependencies.invoker.services.model_manager.install
+        if metadata.is_diffusers:
+            installer.heuristic_import(
+                source=source,
+                inplace=False,
+            )
+        elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
+            installer.heuristic_import(
+                source=str(metadata.ckpt_urls[0]),
+                inplace=False,
+            )
+        else:
+            title = "Unable to Install Model"
+            heading = "This HuggingFace repo has multiple models."
+            message = "Please use the Model Manager to install this model."
+            return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200)
+
+        title = "Model Install Started"
+        heading = "Your HuggingFace model is installing now."
+        message = "You can close this tab and check the Model Manager for installation progress."
+        return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201)
+    except Exception as e:
+        logger.error(str(e))
+        title = "Unable to Install Model"
+        heading = "There was a problem installing this model."
+        message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on <a href="https://discord.gg/ZmtBAhwWhy">discord</a>.'
+        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500)
+
+
 @model_manager_router.get(
     "/install",
     operation_id="list_model_installs",
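For context, a hedged sketch of exercising the new route from a client. The host, port, and the `/api/v2/models` prefix are assumptions about a typical local install and are not shown in this diff; only the `/install/huggingface` path and the `source` query parameter come from the code above.

```python
import requests

# Hypothetical local server address; adjust host, port and prefix to your deployment.
resp = requests.get(
    "http://127.0.0.1:9090/api/v2/models/install/huggingface",
    params={"source": "stabilityai/sd-turbo"},  # illustrative repo_id
    timeout=30,
)

# 201 means the install was queued; other statuses return a small HTML page describing the problem.
print(resp.status_code)
print(resp.text[:200])
```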
@@ -3,9 +3,7 @@ import logging
 import mimetypes
 import socket
 from contextlib import asynccontextmanager
-from inspect import signature
 from pathlib import Path
-from typing import Any
 
 import torch
 import uvicorn
@@ -13,11 +11,9 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
-from fastapi.openapi.utils import get_openapi
 from fastapi.responses import HTMLResponse
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.middleware import EventHandlerASGIMiddleware
-from pydantic.json_schema import models_json_schema
 from torch.backends.mps import is_available as is_mps_available
 
 # for PyCharm:
@@ -25,10 +21,8 @@ from torch.backends.mps import is_available as is_mps_available
 import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)
 import invokeai.frontend.web as web_dir
 from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
-from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.services.events.events_common import EventBase
-from invokeai.app.services.session_processor.session_processor_common import ProgressImage
+from invokeai.app.util.custom_openapi import get_openapi_func
 from invokeai.backend.util.devices import TorchDevice
 
 from ..backend.util.logging import InvokeAILogger
@@ -45,11 +39,6 @@ from .api.routers import (
     workflows,
 )
 from .api.sockets import SocketIO
-from .invocations.baseinvocation import (
-    BaseInvocation,
-    UIConfigBase,
-)
-from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
 
 app_config = get_config()
 
@@ -119,84 +108,7 @@ app.include_router(app_info.app_router, prefix="/api")
 app.include_router(session_queue.session_queue_router, prefix="/api")
 app.include_router(workflows.workflows_router, prefix="/api")
 
+app.openapi = get_openapi_func(app)
-# Build a custom OpenAPI to include all outputs
-# TODO: can outputs be included on metadata of invocation schemas somehow?
-def custom_openapi() -> dict[str, Any]:
-    if app.openapi_schema:
-        return app.openapi_schema
-    openapi_schema = get_openapi(
-        title=app.title,
-        description="An API for invoking AI image operations",
-        version="1.0.0",
-        routes=app.routes,
-        separate_input_output_schemas=False,  # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
-    )
-
-    # Add all outputs
-    all_invocations = BaseInvocation.get_invocations()
-    output_types = set()
-    output_type_titles = {}
-    for invoker in all_invocations:
-        output_type = signature(invoker.invoke).return_annotation
-        output_types.add(output_type)
-
-    output_schemas = models_json_schema(
-        models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
-    )
-    for schema_key, output_schema in output_schemas[1]["$defs"].items():
-        # TODO: note that we assume the schema_key here is the TYPE.__name__
-        # This could break in some cases, figure out a better way to do it
-        output_type_titles[schema_key] = output_schema["title"]
-        openapi_schema["components"]["schemas"][schema_key] = output_schema
-        openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
-
-    # Some models don't end up in the schemas as standalone definitions
-    additional_schemas = models_json_schema(
-        [
-            (UIConfigBase, "serialization"),
-            (InputFieldJSONSchemaExtra, "serialization"),
-            (OutputFieldJSONSchemaExtra, "serialization"),
-            (ModelIdentifierField, "serialization"),
-            (ProgressImage, "serialization"),
-        ],
-        ref_template="#/components/schemas/{model}",
-    )
-    for schema_key, schema_json in additional_schemas[1]["$defs"].items():
-        openapi_schema["components"]["schemas"][schema_key] = schema_json
-
-    openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
-        "type": "object",
-        "properties": {},
-        "required": [],
-    }
-
-    # Add a reference to the output type to additionalProperties of the invoker schema
-    for invoker in all_invocations:
-        invoker_name = invoker.__name__  # type: ignore [attr-defined] # this is a valid attribute
-        output_type = signature(obj=invoker.invoke).return_annotation
-        output_type_title = output_type_titles[output_type.__name__]
-        invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
-        outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
-        invoker_schema["output"] = outputs_ref
-        openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
-        openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
-        invoker_schema["class"] = "invocation"
-
-    # Add all event schemas
-    for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
-        json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
-        if "$defs" in json_schema:
-            for schema_key, schema in json_schema["$defs"].items():
-                openapi_schema["components"]["schemas"][schema_key] = schema
-            del json_schema["$defs"]
-        openapi_schema["components"]["schemas"][event.__name__] = json_schema
-
-    app.openapi_schema = openapi_schema
-    return app.openapi_schema
-
-
-app.openapi = custom_openapi  # type: ignore [method-assign] # this is a valid assignment
 
 
 @app.get("/docs", include_in_schema=False)
@@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel):
 
     _output_classes: ClassVar[set[BaseInvocationOutput]] = set()
     _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
+    _typeadapter_needs_update: ClassVar[bool] = False
 
     @classmethod
     def register_output(cls, output: BaseInvocationOutput) -> None:
         """Registers an invocation output."""
         cls._output_classes.add(output)
+        cls._typeadapter_needs_update = True
 
     @classmethod
     def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
@@ -112,11 +114,12 @@ class BaseInvocationOutput(BaseModel):
     @classmethod
     def get_typeadapter(cls) -> TypeAdapter[Any]:
         """Gets a pydantic TypeAdapter for the union of all invocation output types."""
-        if not cls._typeadapter:
-            InvocationOutputsUnion = TypeAliasType(
-                "InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
+        if not cls._typeadapter or cls._typeadapter_needs_update:
+            AnyInvocationOutput = TypeAliasType(
+                "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
             )
-            cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
+            cls._typeadapter = TypeAdapter(AnyInvocationOutput)
+            cls._typeadapter_needs_update = False
         return cls._typeadapter
 
     @classmethod
@@ -125,12 +128,13 @@ class BaseInvocationOutput(BaseModel):
         return (i.get_type() for i in BaseInvocationOutput.get_outputs())
 
     @staticmethod
-    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
+    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
         """Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
         # Because we use a pydantic Literal field with default value for the invocation type,
         # it will be typed as optional in the OpenAPI schema. Make it required manually.
         if "required" not in schema or not isinstance(schema["required"], list):
             schema["required"] = []
+        schema["class"] = "output"
         schema["required"].extend(["type"])
 
     @classmethod
@@ -167,6 +171,7 @@ class BaseInvocation(ABC, BaseModel):
 
     _invocation_classes: ClassVar[set[BaseInvocation]] = set()
     _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
+    _typeadapter_needs_update: ClassVar[bool] = False
 
     @classmethod
     def get_type(cls) -> str:
@@ -177,15 +182,17 @@ class BaseInvocation(ABC, BaseModel):
     def register_invocation(cls, invocation: BaseInvocation) -> None:
         """Registers an invocation."""
         cls._invocation_classes.add(invocation)
+        cls._typeadapter_needs_update = True
 
     @classmethod
     def get_typeadapter(cls) -> TypeAdapter[Any]:
         """Gets a pydantic TypeAdapter for the union of all invocation types."""
-        if not cls._typeadapter:
-            InvocationsUnion = TypeAliasType(
-                "InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
+        if not cls._typeadapter or cls._typeadapter_needs_update:
+            AnyInvocation = TypeAliasType(
+                "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
             )
-            cls._typeadapter = TypeAdapter(InvocationsUnion)
+            cls._typeadapter = TypeAdapter(AnyInvocation)
+            cls._typeadapter_needs_update = False
         return cls._typeadapter
 
     @classmethod
@@ -221,7 +228,7 @@ class BaseInvocation(ABC, BaseModel):
         return signature(cls.invoke).return_annotation
 
     @staticmethod
-    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
+    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
         """Adds various UI-facing attributes to the invocation's OpenAPI schema."""
         uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
         if uiconfig is not None:
@@ -237,6 +244,7 @@ class BaseInvocation(ABC, BaseModel):
             schema["version"] = uiconfig.version
         if "required" not in schema or not isinstance(schema["required"], list):
             schema["required"] = []
+        schema["class"] = "invocation"
         schema["required"].extend(["type", "id"])
 
     @abstractmethod
@@ -310,7 +318,7 @@ class BaseInvocation(ABC, BaseModel):
         protected_namespaces=(),
         validate_assignment=True,
         json_schema_extra=json_schema_extra,
-        json_schema_serialization_defaults_required=True,
+        json_schema_serialization_defaults_required=False,
         coerce_numbers_to_str=True,
     )
 
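The recurring change in this file is a cache-invalidation flag: registering a new invocation or output class marks the cached pydantic `TypeAdapter` as stale so it is rebuilt on next use rather than being built once and frozen. A self-contained sketch of the same pattern, using hypothetical output classes rather than InvokeAI's own and assuming pydantic v2 with typing_extensions:

```python
from typing import Annotated, Any, ClassVar, Literal, Optional, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import TypeAliasType


class OutputRegistry:
    """Rebuild the cached TypeAdapter only when a new output class has been registered."""

    _classes: ClassVar[set] = set()
    _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
    _needs_update: ClassVar[bool] = False

    @classmethod
    def register(cls, output_cls):
        cls._classes.add(output_cls)
        cls._needs_update = True  # invalidate the cached adapter
        return output_cls

    @classmethod
    def get_typeadapter(cls) -> TypeAdapter[Any]:
        if not cls._typeadapter or cls._needs_update:
            # Discriminated union over everything registered so far.
            AnyOutput = TypeAliasType(
                "AnyOutput", Annotated[Union[tuple(cls._classes)], Field(discriminator="type")]
            )
            cls._typeadapter = TypeAdapter(AnyOutput)
            cls._needs_update = False
        return cls._typeadapter


@OutputRegistry.register
class ImageOutput(BaseModel):
    type: Literal["image_output"] = "image_output"
    image_name: str


@OutputRegistry.register
class LatentsOutput(BaseModel):
    type: Literal["latents_output"] = "latents_output"
    latents_name: str


adapter = OutputRegistry.get_typeadapter()
parsed = adapter.validate_python({"type": "latents_output", "latents_name": "abc"})
assert isinstance(parsed, LatentsOutput)
```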
invokeai/app/invocations/blend_latents.py (new file, 98 lines)

@@ -0,0 +1,98 @@
+from typing import Any, Union
+
+import numpy as np
+import numpy.typing as npt
+import torch
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "lblend",
+    title="Blend Latents",
+    tags=["latents", "blend"],
+    category="latents",
+    version="1.0.3",
+)
+class BlendLatentsInvocation(BaseInvocation):
+    """Blend two latents using a given alpha. Latents must have same size."""
+
+    latents_a: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    latents_b: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents_a = context.tensors.load(self.latents_a.latents_name)
+        latents_b = context.tensors.load(self.latents_b.latents_name)
+
+        if latents_a.shape != latents_b.shape:
+            raise Exception("Latents to blend must be the same size.")
+
+        device = TorchDevice.choose_torch_device()
+
+        def slerp(
+            t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?
+            v0: Union[torch.Tensor, npt.NDArray[Any]],
+            v1: Union[torch.Tensor, npt.NDArray[Any]],
+            DOT_THRESHOLD: float = 0.9995,
+        ) -> Union[torch.Tensor, npt.NDArray[Any]]:
+            """
+            Spherical linear interpolation
+            Args:
+                t (float/np.ndarray): Float value between 0.0 and 1.0
+                v0 (np.ndarray): Starting vector
+                v1 (np.ndarray): Final vector
+                DOT_THRESHOLD (float): Threshold for considering the two vectors as
+                                       colinear. Not recommended to alter this.
+            Returns:
+                v2 (np.ndarray): Interpolation vector between v0 and v1
+            """
+            inputs_are_torch = False
+            if not isinstance(v0, np.ndarray):
+                inputs_are_torch = True
+                v0 = v0.detach().cpu().numpy()
+            if not isinstance(v1, np.ndarray):
+                inputs_are_torch = True
+                v1 = v1.detach().cpu().numpy()
+
+            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+            if np.abs(dot) > DOT_THRESHOLD:
+                v2 = (1 - t) * v0 + t * v1
+            else:
+                theta_0 = np.arccos(dot)
+                sin_theta_0 = np.sin(theta_0)
+                theta_t = theta_0 * t
+                sin_theta_t = np.sin(theta_t)
+                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+                s1 = sin_theta_t / sin_theta_0
+                v2 = s0 * v0 + s1 * v1
+
+            if inputs_are_torch:
+                v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
+                return v2_torch
+            else:
+                assert isinstance(v2, np.ndarray)
+                return v2
+
+        # blend
+        bl = slerp(self.alpha, latents_a, latents_b)
+        assert isinstance(bl, torch.Tensor)
+        blended_latents: torch.Tensor = bl  # for type checking convenience
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        blended_latents = blended_latents.to("cpu")
+
+        TorchDevice.empty_cache()
+
+        name = context.tensors.save(tensor=blended_latents)
+        return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
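As a quick numerical illustration of the spherical interpolation used by the new invocation (a standalone sketch, not part of the file above): at `alpha=0` slerp returns the first vector, at `alpha=1` the second, and intermediate values stay on the unit sphere instead of cutting through it the way a straight lerp would.

```python
import numpy as np


def slerp(t: float, v0: np.ndarray, v1: np.ndarray) -> np.ndarray:
    # Simplified numpy-only version of the helper nested in BlendLatentsInvocation.
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > 0.9995:
        return (1 - t) * v0 + t * v1  # nearly colinear: fall back to plain lerp
    theta_0 = np.arccos(dot)
    theta_t = theta_0 * t
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1


v0 = np.array([1.0, 0.0])  # two orthogonal unit vectors
v1 = np.array([0.0, 1.0])
for t in (0.0, 0.5, 1.0):
    v = slerp(t, v0, v1)
    print(t, v, np.linalg.norm(v))  # the norm stays ~1.0, unlike a straight lerp at t=0.5
```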
@@ -81,9 +81,13 @@ class CompelInvocation(BaseInvocation):
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info as text_encoder,
+            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            ModelPatcher.apply_lora_text_encoder(
+                text_encoder,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
             ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
@@ -172,9 +176,14 @@ class SDXLPromptInvocationBase:
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info as text_encoder,
+            text_encoder_info.model_on_device() as (state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            ModelPatcher.apply_lora(
+                text_encoder,
+                loras=_lora_loader(),
+                prefix=lora_prefix,
+                model_state_dict=state_dict,
+            ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
             ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
@@ -1,6 +1,7 @@
 from typing import Literal
 
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from invokeai.backend.util.devices import TorchDevice
 
 LATENT_SCALE_FACTOR = 8
 """
@@ -15,3 +16,5 @@ SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
 
 IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
 """A literal type for PIL image modes supported by Invoke"""
+
+DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
@@ -2,6 +2,7 @@
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import bool, float
+from pathlib import Path
 from typing import Dict, List, Literal, Union
 
 import cv2
@@ -36,12 +37,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
 from invokeai.backend.image_util.canny import get_canny_edges
-from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
-from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
+from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
+from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
 from invokeai.backend.image_util.hed import HEDProcessor
 from invokeai.backend.image_util.lineart import LineartProcessor
 from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
 from invokeai.backend.image_util.util import np_to_pil, pil_to_np
+from invokeai.backend.util.devices import TorchDevice
 
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
 
@@ -139,6 +141,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
         return context.images.get_pil(self.image.image_name, "RGB")
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
+        self._context = context
         raw_image = self.load_image(context)
         # image type should be PIL.PngImagePlugin.PngImageFile ?
         processed_image = self.run_processor(raw_image)
@@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     # depth_and_normal not supported in controlnet_aux v0.0.3
    # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
 
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
+        # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
         midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = midas_processor(
             image,
@@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = normalbae_processor(
             image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
@@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
     thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
 
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
         processed_image = mlsd_processor(
             image,
@@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
 
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = pidi_processor(
             image,
@@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
     f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
 
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         content_shuffle_processor = ContentShuffleDetector()
         processed_image = content_shuffle_processor(
|
processed_image = content_shuffle_processor(
|
||||||
image,
|
image,
|
||||||
@@ -405,7 +409,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Zoe depth processing to image"""
|
"""Applies Zoe depth processing to image"""
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
processed_image = zoe_depth_processor(image)
|
processed_image = zoe_depth_processor(image)
|
||||||
return processed_image
|
return processed_image
|
||||||
@@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
|||||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
mediapipe_face_processor = MediapipeFaceDetector()
|
mediapipe_face_processor = MediapipeFaceDetector()
|
||||||
processed_image = mediapipe_face_processor(
|
processed_image = mediapipe_face_processor(
|
||||||
image,
|
image,
|
||||||
@@ -454,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
processed_image = leres_processor(
|
processed_image = leres_processor(
|
||||||
image,
|
image,
|
||||||
@@ -496,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
|||||||
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
|
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
|
||||||
return np_img
|
return np_img
|
||||||
|
|
||||||
def run_processor(self, img):
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
np_img = np.array(img, dtype=np.uint8)
|
np_img = np.array(image, dtype=np.uint8)
|
||||||
processed_np_image = self.tile_resample(
|
processed_np_image = self.tile_resample(
|
||||||
np_img,
|
np_img,
|
||||||
# res=self.tile_size,
|
# res=self.tile_size,
|
||||||
@@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
|||||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
"ybelkada/segment-anything", subfolder="checkpoints"
|
||||||
@@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
|
|
||||||
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
|
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image):
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
np_image = np.array(image, dtype=np.uint8)
|
np_image = np.array(image, dtype=np.uint8)
|
||||||
height, width = np_image.shape[:2]
|
height, width = np_image.shape[:2]
|
||||||
|
|
||||||
@@ -601,12 +605,18 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
)
|
)
|
||||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image):
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
depth_anything_detector = DepthAnythingDetector()
|
def loader(model_path: Path):
|
||||||
depth_anything_detector.load_model(model_size=self.model_size)
|
return DepthAnythingDetector.load_model(
|
||||||
|
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
|
||||||
|
)
|
||||||
|
|
||||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
with self._context.models.load_remote_model(
|
||||||
return processed_image
|
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
|
||||||
|
) as model:
|
||||||
|
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
|
||||||
|
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
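The Depth Anything change above moves model acquisition out of the detector (`from_pretrained`) and into the invocation context: a `loader` callback builds the model from a locally cached file, and `context.models.load_remote_model(source=..., loader=...)` handles download, caching and lifetime. Below is a minimal sketch of the same pattern applied to a hypothetical custom processor node; `EdgeDetector` and the URL are placeholders for illustration only, not part of InvokeAI.

from pathlib import Path

EDGE_MODEL_URL = "https://example.com/edge_detector.pth"  # hypothetical remote source


class EdgeProcessorInvocation(ImageProcessorInvocation):  # hypothetical node, sketch only
    def run_processor(self, image: Image.Image) -> Image.Image:
        def loader(model_path: Path):
            # Build the in-memory model from the cached file on disk (placeholder detector).
            return EdgeDetector.load_model(model_path, device=TorchDevice.choose_torch_device())

        # load_remote_model downloads the file if needed, caches it, and yields a loaded,
        # cache-managed model for the duration of the `with` block, mirroring the hunk above.
        with self._context.models.load_remote_model(source=EDGE_MODEL_URL, loader=loader) as model:
            detector = EdgeDetector(model, TorchDevice.choose_torch_device())
            return detector(image)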
@@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
    draw_hands: bool = InputField(default=False)
    image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image):
-        dw_openpose = DWOpenposeDetector()
+    def run_processor(self, image: Image.Image) -> Image.Image:
+        onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
+        onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
+
+        dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
        processed_image = dw_openpose(
            image,
            draw_face=self.draw_face,
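Note the contrast between the two context APIs used in these hunks: `download_and_cache_model` (DW Openpose) resolves a remote source to a cached local file and hands the result to the caller, while `load_remote_model` (Depth Anything) additionally loads the file via the supplied `loader` and keeps the resulting model under the model manager's cache. A condensed, illustrative comparison using only the calls that appear in this diff:

# download_and_cache_model: the detector receives the cached artifact and loads it itself.
onnx_det = context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])

# load_remote_model: yields an already-loaded, cache-managed model for the `with` block.
with context.models.load_remote_model(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model:
    ...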
invokeai/app/invocations/create_denoise_mask.py (new file, 80 lines)
@@ -0,0 +1,80 @@
from typing import Optional

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import DenoiseMaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor


@invocation(
    "create_denoise_mask",
    title="Create Denoise Mask",
    tags=["mask", "denoise"],
    category="latents",
    version="1.0.2",
)
class CreateDenoiseMaskInvocation(BaseInvocation):
    """Creates mask for denoising model run."""

    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
    image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
    mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
    fp32: bool = InputField(
        default=DEFAULT_PRECISION == torch.float32,
        description=FieldDescriptions.fp32,
        ui_order=4,
    )

    def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
        if mask_image.mode != "L":
            mask_image = mask_image.convert("L")
        mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
        if mask_tensor.dim() == 3:
            mask_tensor = mask_tensor.unsqueeze(0)
        # if shape is not None:
        #     mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
        return mask_tensor

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
        if self.image is not None:
            image = context.images.get_pil(self.image.image_name)
            image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
            if image_tensor.dim() == 3:
                image_tensor = image_tensor.unsqueeze(0)
        else:
            image_tensor = None

        mask = self.prep_mask_tensor(
            context.images.get_pil(self.mask.image_name),
        )

        if image_tensor is not None:
            vae_info = context.models.load(self.vae.vae)

            img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
            masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
            # TODO:
            masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())

            masked_latents_name = context.tensors.save(tensor=masked_latents)
        else:
            masked_latents_name = None

        mask_name = context.tensors.save(tensor=mask)

        return DenoiseMaskOutput.build(
            mask_name=mask_name,
            masked_latents_name=masked_latents_name,
            gradient=False,
        )
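A small, self-contained sketch of the masking step above: pixels where the (resized) mask falls below 0.5 are zeroed out before VAE encoding, everything else passes through unchanged. The tensor values are made up purely for illustration.

import torch

# Toy 1x1x2x2 "image" and mask; values are illustrative only.
image_tensor = torch.tensor([[[[0.2, 0.8], [0.5, 1.0]]]])
img_mask = torch.tensor([[[[0.1, 0.9], [0.4, 0.6]]]])

# Same operation as CreateDenoiseMaskInvocation: binarize the mask at 0.5, then gate the image.
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
print(masked_image)  # values: 0.0, 0.8, 0.0, 1.0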
invokeai/app/invocations/create_gradient_mask.py (new file, 138 lines)
@@ -0,0 +1,138 @@
from typing import Literal, Optional

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageFilter
from torchvision.transforms.functional import resize as tv_resize

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
    DenoiseMaskField,
    FieldDescriptions,
    ImageField,
    Input,
    InputField,
    OutputField,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor


@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
    """Outputs a denoise mask and an image representing the total gradient of the mask."""

    denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
    expanded_mask_area: ImageField = OutputField(
        description="Image representing the total gradient area of the mask. For paste-back purposes."
    )


@invocation(
    "create_gradient_mask",
    title="Create Gradient Mask",
    tags=["mask", "denoise"],
    category="latents",
    version="1.1.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
    """Creates mask for denoising model run."""

    mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
    edge_radius: int = InputField(
        default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
    )
    coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
    minimum_denoise: float = InputField(
        default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
    )
    image: Optional[ImageField] = InputField(
        default=None,
        description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
        title="[OPTIONAL] Image",
        ui_order=6,
    )
    unet: Optional[UNetField] = InputField(
        description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
        default=None,
        input=Input.Connection,
        title="[OPTIONAL] UNet",
        ui_order=5,
    )
    vae: Optional[VAEField] = InputField(
        default=None,
        description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
        title="[OPTIONAL] VAE",
        input=Input.Connection,
        ui_order=7,
    )
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
    fp32: bool = InputField(
        default=DEFAULT_PRECISION == torch.float32,
        description=FieldDescriptions.fp32,
        ui_order=9,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> GradientMaskOutput:
        mask_image = context.images.get_pil(self.mask.image_name, mode="L")
        if self.edge_radius > 0:
            if self.coherence_mode == "Box Blur":
                blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
            else:  # Gaussian Blur OR Staged
                # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
                blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))

            blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)

            # redistribute blur so that the original edges are 0 and blur outwards to 1
            blur_tensor = (blur_tensor - 0.5) * 2

            threshold = 1 - self.minimum_denoise

            if self.coherence_mode == "Staged":
                # wherever the blur_tensor is less than fully masked, convert it to threshold
                blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
            else:
                # wherever the blur_tensor is above threshold but less than 1, drop it to threshold
                blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)

        else:
            blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)

        mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))

        # compute a [0, 1] mask from the blur_tensor
        expanded_mask = torch.where((blur_tensor < 1), 0, 1)
        expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
        expanded_image_dto = context.images.save(expanded_mask_image)

        masked_latents_name = None
        if self.unet is not None and self.vae is not None and self.image is not None:
            # all three fields must be present at the same time
            main_model_config = context.models.get_config(self.unet.unet.key)
            assert isinstance(main_model_config, MainConfigBase)
            if main_model_config.variant is ModelVariantType.Inpaint:
                mask = blur_tensor
                vae_info: LoadedModel = context.models.load(self.vae.vae)
                image = context.images.get_pil(self.image.image_name)
                image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
                if image_tensor.dim() == 3:
                    image_tensor = image_tensor.unsqueeze(0)
                img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
                masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
                masked_latents = ImageToLatentsInvocation.vae_encode(
                    vae_info, self.fp32, self.tiled, masked_image.clone()
                )
                masked_latents_name = context.tensors.save(tensor=masked_latents)

        return GradientMaskOutput(
            denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
            expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
        )
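A quick numeric illustration of the redistribution step in the invocation above (sample values are made up): after `(blur_tensor - 0.5) * 2`, a fully white pixel (1.0) stays at 1.0, the original mask edge (0.5 after blurring) maps to 0.0, values at or below 0 are the fully denoised interior, and the partially blurred band in between is capped at `1 - minimum_denoise` in the default (non-Staged) modes.

import torch

blur_tensor = torch.tensor([1.00, 0.90, 0.50, 0.10])  # blurred mask samples (illustrative)
blur_tensor = (blur_tensor - 0.5) * 2                  # -> 1.0, 0.8, 0.0, -0.8

minimum_denoise = 0.3
threshold = 1 - minimum_denoise                        # 0.7

# Default (non-Staged) behaviour: cap the partially-blurred band at the threshold.
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
print(blur_tensor)  # values: 1.0, 0.7, 0.0, -0.8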
invokeai/app/invocations/crop_latents.py (new file, 61 lines)
@@ -0,0 +1,61 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext


# The Crop Latents node was copied from @skunkworxdark's implementation here:
# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
@invocation(
    "crop_latents",
    title="Crop Latents",
    tags=["latents", "crop"],
    category="latents",
    version="1.0.2",
)
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
# Currently, if the class names conflict then 'GET /openapi.json' fails.
class CropLatentsCoreInvocation(BaseInvocation):
    """Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
    divisible by the latent scale factor of 8.
    """

    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    x: int = InputField(
        ge=0,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )
    y: int = InputField(
        ge=0,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )
    width: int = InputField(
        ge=1,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )
    height: int = InputField(
        ge=1,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.tensors.load(self.latents.latents_name)

        x1 = self.x // LATENT_SCALE_FACTOR
        y1 = self.y // LATENT_SCALE_FACTOR
        x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
        y2 = y1 + (self.height // LATENT_SCALE_FACTOR)

        cropped_latents = latents[..., y1:y2, x1:x2]

        name = context.tensors.save(tensor=cropped_latents)

        return LatentsOutput.build(latents_name=name, latents=cropped_latents)
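A worked example of the coordinate conversion in `invoke` above, with LATENT_SCALE_FACTOR = 8 and made-up crop values:

LATENT_SCALE_FACTOR = 8  # as defined in invokeai.app.invocations.constants

# Illustrative crop: a 256x256 box whose top-left corner is at (64, 128) in image space.
x, y, width, height = 64, 128, 256, 256

x1 = x // LATENT_SCALE_FACTOR              # 8
y1 = y // LATENT_SCALE_FACTOR              # 16
x2 = x1 + (width // LATENT_SCALE_FACTOR)   # 8 + 32 = 40
y2 = y1 + (height // LATENT_SCALE_FACTOR)  # 16 + 32 = 48

# latents[..., y1:y2, x1:x2] then selects a 32x32 latent patch,
# which corresponds to the requested 256x256 region in image space.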
848
invokeai/app/invocations/denoise_latents.py
Normal file
848
invokeai/app/invocations/denoise_latents.py
Normal file
@@ -0,0 +1,848 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
import inspect
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
|
from diffusers.models.adapter import T2IAdapter
|
||||||
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
|
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
|
||||||
|
from diffusers.schedulers.scheduling_tcd import TCDScheduler
|
||||||
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
|
||||||
|
from pydantic import field_validator
|
||||||
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
from transformers import CLIPVisionModelWithProjection
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||||
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
|
from invokeai.app.invocations.fields import (
|
||||||
|
ConditioningField,
|
||||||
|
DenoiseMaskField,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
LatentsField,
|
||||||
|
UIType,
|
||||||
|
)
|
||||||
|
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||||
|
from invokeai.app.invocations.model import ModelIdentifierField, UNetField
|
||||||
|
from invokeai.app.invocations.primitives import LatentsOutput
|
||||||
|
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
|
from invokeai.backend.model_manager import BaseModelType
|
||||||
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||||
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||||
|
ControlNetData,
|
||||||
|
StableDiffusionGeneratorPipeline,
|
||||||
|
T2IAdapterData,
|
||||||
|
)
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
|
BasicConditioningInfo,
|
||||||
|
IPAdapterConditioningInfo,
|
||||||
|
IPAdapterData,
|
||||||
|
Range,
|
||||||
|
SDXLConditioningInfo,
|
||||||
|
TextConditioningData,
|
||||||
|
TextConditioningRegions,
|
||||||
|
)
|
||||||
|
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||||
|
from invokeai.backend.util.mask import to_standard_float_mask
|
||||||
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler(
|
||||||
|
context: InvocationContext,
|
||||||
|
scheduler_info: ModelIdentifierField,
|
||||||
|
scheduler_name: str,
|
||||||
|
seed: int,
|
||||||
|
) -> Scheduler:
|
||||||
|
"""Load a scheduler and apply some scheduler-specific overrides."""
|
||||||
|
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
|
||||||
|
# possible.
|
||||||
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||||
|
orig_scheduler_info = context.models.load(scheduler_info)
|
||||||
|
with orig_scheduler_info as orig_scheduler:
|
||||||
|
scheduler_config = orig_scheduler.config
|
||||||
|
|
||||||
|
if "_backup" in scheduler_config:
|
||||||
|
scheduler_config = scheduler_config["_backup"]
|
||||||
|
scheduler_config = {
|
||||||
|
**scheduler_config,
|
||||||
|
**scheduler_extra_config, # FIXME
|
||||||
|
"_backup": scheduler_config,
|
||||||
|
}
|
||||||
|
|
||||||
|
# make dpmpp_sde reproducable(seed can be passed only in initializer)
|
||||||
|
if scheduler_class is DPMSolverSDEScheduler:
|
||||||
|
scheduler_config["noise_sampler_seed"] = seed
|
||||||
|
|
||||||
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
|
|
||||||
|
# hack copied over from generate.py
|
||||||
|
if not hasattr(scheduler, "uses_inpainting_model"):
|
||||||
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
|
assert isinstance(scheduler, Scheduler)
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"denoise_latents",
|
||||||
|
title="Denoise Latents",
|
||||||
|
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||||
|
category="latents",
|
||||||
|
version="1.5.3",
|
||||||
|
)
|
||||||
|
class DenoiseLatentsInvocation(BaseInvocation):
|
||||||
|
"""Denoises noisy latents to decodable images"""
|
||||||
|
|
||||||
|
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||||
|
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
||||||
|
)
|
||||||
|
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||||
|
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
||||||
|
)
|
||||||
|
noise: Optional[LatentsField] = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.noise,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=3,
|
||||||
|
)
|
||||||
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||||
|
cfg_scale: Union[float, List[float]] = InputField(
|
||||||
|
default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
||||||
|
)
|
||||||
|
denoising_start: float = InputField(
|
||||||
|
default=0.0,
|
||||||
|
ge=0,
|
||||||
|
le=1,
|
||||||
|
description=FieldDescriptions.denoising_start,
|
||||||
|
)
|
||||||
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||||
|
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||||
|
default="euler",
|
||||||
|
description=FieldDescriptions.scheduler,
|
||||||
|
ui_type=UIType.Scheduler,
|
||||||
|
)
|
||||||
|
unet: UNetField = InputField(
|
||||||
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="UNet",
|
||||||
|
ui_order=2,
|
||||||
|
)
|
||||||
|
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||||
|
default=None,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=5,
|
||||||
|
)
|
||||||
|
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
|
||||||
|
description=FieldDescriptions.ip_adapter,
|
||||||
|
title="IP-Adapter",
|
||||||
|
default=None,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=6,
|
||||||
|
)
|
||||||
|
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
|
||||||
|
description=FieldDescriptions.t2i_adapter,
|
||||||
|
title="T2I-Adapter",
|
||||||
|
default=None,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=7,
|
||||||
|
)
|
||||||
|
cfg_rescale_multiplier: float = InputField(
|
||||||
|
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||||
|
)
|
||||||
|
latents: Optional[LatentsField] = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.latents,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=4,
|
||||||
|
)
|
||||||
|
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.mask,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("cfg_scale")
|
||||||
|
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
|
||||||
|
"""validate that all cfg_scale values are >= 1"""
|
||||||
|
if isinstance(v, list):
|
||||||
|
for i in v:
|
||||||
|
if i < 1:
|
||||||
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
|
else:
|
||||||
|
if v < 1:
|
||||||
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_text_embeddings_and_masks(
|
||||||
|
cond_list: list[ConditioningField],
|
||||||
|
context: InvocationContext,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
|
||||||
|
"""Get the text embeddings and masks from the input conditioning fields."""
|
||||||
|
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
|
||||||
|
text_embeddings_masks: list[Optional[torch.Tensor]] = []
|
||||||
|
for cond in cond_list:
|
||||||
|
cond_data = context.conditioning.load(cond.conditioning_name)
|
||||||
|
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
|
||||||
|
|
||||||
|
mask = cond.mask
|
||||||
|
if mask is not None:
|
||||||
|
mask = context.tensors.load(mask.tensor_name)
|
||||||
|
text_embeddings_masks.append(mask)
|
||||||
|
|
||||||
|
return text_embeddings, text_embeddings_masks
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _preprocess_regional_prompt_mask(
|
||||||
|
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Preprocess a regional prompt mask to match the target height and width.
|
||||||
|
If mask is None, returns a mask of all ones with the target height and width.
|
||||||
|
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
|
||||||
|
"""
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
return torch.ones((1, 1, target_height, target_width), dtype=dtype)
|
||||||
|
|
||||||
|
mask = to_standard_float_mask(mask, out_dtype=dtype)
|
||||||
|
|
||||||
|
tf = torchvision.transforms.Resize(
|
||||||
|
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
||||||
|
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||||
|
resized_mask = tf(mask)
|
||||||
|
return resized_mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _concat_regional_text_embeddings(
|
||||||
|
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
||||||
|
masks: Optional[list[Optional[torch.Tensor]]],
|
||||||
|
latent_height: int,
|
||||||
|
latent_width: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
|
||||||
|
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
|
||||||
|
if masks is None:
|
||||||
|
masks = [None] * len(text_conditionings)
|
||||||
|
assert len(text_conditionings) == len(masks)
|
||||||
|
|
||||||
|
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
|
||||||
|
|
||||||
|
all_masks_are_none = all(mask is None for mask in masks)
|
||||||
|
|
||||||
|
text_embedding = []
|
||||||
|
pooled_embedding = None
|
||||||
|
add_time_ids = None
|
||||||
|
cur_text_embedding_len = 0
|
||||||
|
processed_masks = []
|
||||||
|
embedding_ranges = []
|
||||||
|
|
||||||
|
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
|
||||||
|
mask = masks[prompt_idx]
|
||||||
|
|
||||||
|
if is_sdxl:
|
||||||
|
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
|
||||||
|
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
|
||||||
|
# global prompt information. In an ideal case, there should be exactly one global prompt without a
|
||||||
|
# mask, but we don't enforce this.
|
||||||
|
|
||||||
|
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
|
||||||
|
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
|
||||||
|
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
|
||||||
|
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
|
||||||
|
# pretty major breaking change to a popular node, so for now we use this hack.
|
||||||
|
if pooled_embedding is None or mask is None:
|
||||||
|
pooled_embedding = text_embedding_info.pooled_embeds
|
||||||
|
if add_time_ids is None or mask is None:
|
||||||
|
add_time_ids = text_embedding_info.add_time_ids
|
||||||
|
|
||||||
|
text_embedding.append(text_embedding_info.embeds)
|
||||||
|
if not all_masks_are_none:
|
||||||
|
embedding_ranges.append(
|
||||||
|
Range(
|
||||||
|
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
processed_masks.append(
|
||||||
|
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
|
||||||
|
mask, latent_height, latent_width, dtype=dtype
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
||||||
|
|
||||||
|
text_embedding = torch.cat(text_embedding, dim=1)
|
||||||
|
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
|
||||||
|
|
||||||
|
regions = None
|
||||||
|
if not all_masks_are_none:
|
||||||
|
regions = TextConditioningRegions(
|
||||||
|
masks=torch.cat(processed_masks, dim=1),
|
||||||
|
ranges=embedding_ranges,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sdxl:
|
||||||
|
return (
|
||||||
|
SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids),
|
||||||
|
regions,
|
||||||
|
)
|
||||||
|
return BasicConditioningInfo(embeds=text_embedding), regions
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_conditioning_data(
|
||||||
|
context: InvocationContext,
|
||||||
|
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||||
|
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
latent_height: int,
|
||||||
|
latent_width: int,
|
||||||
|
cfg_scale: float | list[float],
|
||||||
|
steps: int,
|
||||||
|
cfg_rescale_multiplier: float,
|
||||||
|
) -> TextConditioningData:
|
||||||
|
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
|
||||||
|
cond_list = positive_conditioning_field
|
||||||
|
if not isinstance(cond_list, list):
|
||||||
|
cond_list = [cond_list]
|
||||||
|
uncond_list = negative_conditioning_field
|
||||||
|
if not isinstance(uncond_list, list):
|
||||||
|
uncond_list = [uncond_list]
|
||||||
|
|
||||||
|
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||||
|
cond_list, context, unet.device, unet.dtype
|
||||||
|
)
|
||||||
|
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||||
|
uncond_list, context, unet.device, unet.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||||
|
text_conditionings=cond_text_embeddings,
|
||||||
|
masks=cond_text_embedding_masks,
|
||||||
|
latent_height=latent_height,
|
||||||
|
latent_width=latent_width,
|
||||||
|
dtype=unet.dtype,
|
||||||
|
)
|
||||||
|
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||||
|
text_conditionings=uncond_text_embeddings,
|
||||||
|
masks=uncond_text_embedding_masks,
|
||||||
|
latent_height=latent_height,
|
||||||
|
latent_width=latent_width,
|
||||||
|
dtype=unet.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(cfg_scale, list):
|
||||||
|
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
|
||||||
|
|
||||||
|
conditioning_data = TextConditioningData(
|
||||||
|
uncond_text=uncond_text_embedding,
|
||||||
|
cond_text=cond_text_embedding,
|
||||||
|
uncond_regions=uncond_regions,
|
||||||
|
cond_regions=cond_regions,
|
||||||
|
guidance_scale=cfg_scale,
|
||||||
|
guidance_rescale_multiplier=cfg_rescale_multiplier,
|
||||||
|
)
|
||||||
|
return conditioning_data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_pipeline(
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
scheduler: Scheduler,
|
||||||
|
) -> StableDiffusionGeneratorPipeline:
|
||||||
|
class FakeVae:
|
||||||
|
class FakeVaeConfig:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.block_out_channels = [0]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.config = FakeVae.FakeVaeConfig()
|
||||||
|
|
||||||
|
return StableDiffusionGeneratorPipeline(
|
||||||
|
vae=FakeVae(), # TODO: oh...
|
||||||
|
text_encoder=None,
|
||||||
|
tokenizer=None,
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prep_control_data(
|
||||||
|
context: InvocationContext,
|
||||||
|
control_input: ControlField | list[ControlField] | None,
|
||||||
|
latents_shape: List[int],
|
||||||
|
exit_stack: ExitStack,
|
||||||
|
do_classifier_free_guidance: bool = True,
|
||||||
|
) -> list[ControlNetData] | None:
|
||||||
|
# Normalize control_input to a list.
|
||||||
|
control_list: list[ControlField]
|
||||||
|
if isinstance(control_input, ControlField):
|
||||||
|
control_list = [control_input]
|
||||||
|
elif isinstance(control_input, list):
|
||||||
|
control_list = control_input
|
||||||
|
elif control_input is None:
|
||||||
|
control_list = []
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
|
||||||
|
|
||||||
|
if len(control_list) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||||
|
_, _, latent_height, latent_width = latents_shape
|
||||||
|
control_height_resize = latent_height * LATENT_SCALE_FACTOR
|
||||||
|
control_width_resize = latent_width * LATENT_SCALE_FACTOR
|
||||||
|
|
||||||
|
controlnet_data: list[ControlNetData] = []
|
||||||
|
for control_info in control_list:
|
||||||
|
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||||
|
assert isinstance(control_model, ControlNetModel)
|
||||||
|
|
||||||
|
control_image_field = control_info.image
|
||||||
|
input_image = context.images.get_pil(control_image_field.image_name)
|
||||||
|
# self.image.image_type, self.image.image_name
|
||||||
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
|
# and add in batch_size, num_images_per_prompt?
|
||||||
|
# and do real check for classifier_free_guidance?
|
||||||
|
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||||
|
control_image = prepare_control_image(
|
||||||
|
image=input_image,
|
||||||
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
|
width=control_width_resize,
|
||||||
|
height=control_height_resize,
|
||||||
|
# batch_size=batch_size * num_images_per_prompt,
|
||||||
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=control_model.device,
|
||||||
|
dtype=control_model.dtype,
|
||||||
|
control_mode=control_info.control_mode,
|
||||||
|
resize_mode=control_info.resize_mode,
|
||||||
|
)
|
||||||
|
control_item = ControlNetData(
|
||||||
|
model=control_model,
|
||||||
|
image_tensor=control_image,
|
||||||
|
weight=control_info.control_weight,
|
||||||
|
begin_step_percent=control_info.begin_step_percent,
|
||||||
|
end_step_percent=control_info.end_step_percent,
|
||||||
|
control_mode=control_info.control_mode,
|
||||||
|
# any resizing needed should currently be happening in prepare_control_image(),
|
||||||
|
# but adding resize_mode to ControlNetData in case needed in the future
|
||||||
|
resize_mode=control_info.resize_mode,
|
||||||
|
)
|
||||||
|
controlnet_data.append(control_item)
|
||||||
|
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||||
|
|
||||||
|
return controlnet_data
|
||||||
|
|
||||||
|
def prep_ip_adapter_image_prompts(
|
||||||
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
ip_adapters: List[IPAdapterField],
|
||||||
|
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
|
||||||
|
image_prompts = []
|
||||||
|
for single_ip_adapter in ip_adapters:
|
||||||
|
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
|
||||||
|
assert isinstance(ip_adapter_model, IPAdapter)
|
||||||
|
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||||
|
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||||
|
single_ipa_image_fields = single_ip_adapter.image
|
||||||
|
if not isinstance(single_ipa_image_fields, list):
|
||||||
|
single_ipa_image_fields = [single_ipa_image_fields]
|
||||||
|
|
||||||
|
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
|
||||||
|
with image_encoder_model_info as image_encoder_model:
|
||||||
|
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||||
|
# Get image embeddings from CLIP and ImageProjModel.
|
||||||
|
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||||
|
single_ipa_images, image_encoder_model
|
||||||
|
)
|
||||||
|
image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))
|
||||||
|
|
||||||
|
return image_prompts
|
||||||
|
|
||||||
|
def prep_ip_adapter_data(
|
||||||
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
ip_adapters: List[IPAdapterField],
|
||||||
|
image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
exit_stack: ExitStack,
|
||||||
|
latent_height: int,
|
||||||
|
latent_width: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> Optional[List[IPAdapterData]]:
|
||||||
|
"""If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data."""
|
||||||
|
ip_adapter_data_list = []
|
||||||
|
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
|
||||||
|
ip_adapters, image_prompts, strict=True
|
||||||
|
):
|
||||||
|
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
|
||||||
|
|
||||||
|
mask_field = single_ip_adapter.mask
|
||||||
|
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
|
||||||
|
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
||||||
|
|
||||||
|
ip_adapter_data_list.append(
|
||||||
|
IPAdapterData(
|
||||||
|
ip_adapter_model=ip_adapter_model,
|
||||||
|
weight=single_ip_adapter.weight,
|
||||||
|
target_blocks=single_ip_adapter.target_blocks,
|
||||||
|
begin_step_percent=single_ip_adapter.begin_step_percent,
|
||||||
|
end_step_percent=single_ip_adapter.end_step_percent,
|
||||||
|
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None
|
||||||
|
|
||||||
|
def run_t2i_adapters(
|
||||||
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||||
|
latents_shape: list[int],
|
||||||
|
do_classifier_free_guidance: bool,
|
||||||
|
) -> Optional[list[T2IAdapterData]]:
|
||||||
|
if t2i_adapter is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
|
||||||
|
if isinstance(t2i_adapter, T2IAdapterField):
|
||||||
|
t2i_adapter = [t2i_adapter]
|
||||||
|
|
||||||
|
if len(t2i_adapter) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
t2i_adapter_data = []
|
||||||
|
for t2i_adapter_field in t2i_adapter:
|
||||||
|
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
||||||
|
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
|
||||||
|
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||||
|
|
||||||
|
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||||
|
if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
|
||||||
|
max_unet_downscale = 8
|
||||||
|
elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
|
||||||
|
max_unet_downscale = 4
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
|
||||||
|
|
||||||
|
t2i_adapter_model: T2IAdapter
|
||||||
|
with t2i_adapter_loaded_model as t2i_adapter_model:
|
||||||
|
total_downscale_factor = t2i_adapter_model.total_downscale_factor
|
||||||
|
|
||||||
|
# Resize the T2I-Adapter input image.
|
||||||
|
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
|
||||||
|
# result will match the latent image's dimensions after max_unet_downscale is applied.
|
||||||
|
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
|
||||||
|
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
|
||||||
|
|
||||||
|
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
|
||||||
|
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
|
||||||
|
# T2I-Adapter model.
|
||||||
|
#
|
||||||
|
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
|
||||||
|
# of the same requirements (e.g. preserving binary masks during resize).
|
||||||
|
t2i_image = prepare_control_image(
|
||||||
|
image=image,
|
||||||
|
do_classifier_free_guidance=False,
|
||||||
|
width=t2i_input_width,
|
||||||
|
height=t2i_input_height,
|
||||||
|
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
|
||||||
|
device=t2i_adapter_model.device,
|
||||||
|
dtype=t2i_adapter_model.dtype,
|
||||||
|
resize_mode=t2i_adapter_field.resize_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter_state = t2i_adapter_model(t2i_image)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
for idx, value in enumerate(adapter_state):
|
||||||
|
adapter_state[idx] = torch.cat([value] * 2, dim=0)
|
||||||
|
|
||||||
|
t2i_adapter_data.append(
|
||||||
|
T2IAdapterData(
|
||||||
|
adapter_state=adapter_state,
|
||||||
|
weight=t2i_adapter_field.weight,
|
||||||
|
begin_step_percent=t2i_adapter_field.begin_step_percent,
|
||||||
|
end_step_percent=t2i_adapter_field.end_step_percent,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return t2i_adapter_data
|
||||||
|
|
||||||
|
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||||
|
# TODO: research more for second order schedulers timesteps
|
||||||
|
@staticmethod
|
||||||
|
def init_scheduler(
|
||||||
|
scheduler: Union[Scheduler, ConfigMixin],
|
||||||
|
device: torch.device,
|
||||||
|
steps: int,
|
||||||
|
denoising_start: float,
|
||||||
|
denoising_end: float,
|
||||||
|
seed: int,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
||||||
|
assert isinstance(scheduler, ConfigMixin)
|
||||||
|
if scheduler.config.get("cpu_only", False):
|
||||||
|
scheduler.set_timesteps(steps, device="cpu")
|
||||||
|
timesteps = scheduler.timesteps.to(device=device)
|
||||||
|
else:
|
||||||
|
scheduler.set_timesteps(steps, device=device)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
|
||||||
|
# skip greater order timesteps
|
||||||
|
_timesteps = timesteps[:: scheduler.order]
|
||||||
|
|
||||||
|
# get start timestep index
|
||||||
|
t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
|
||||||
|
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
|
||||||
|
|
||||||
|
# get end timestep index
|
||||||
|
t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
|
||||||
|
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
|
||||||
|
|
||||||
|
# apply order to indexes
|
||||||
|
t_start_idx *= scheduler.order
|
||||||
|
t_end_idx *= scheduler.order
|
||||||
|
|
||||||
|
init_timestep = timesteps[t_start_idx : t_start_idx + 1]
|
||||||
|
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
||||||
|
|
||||||
|
scheduler_step_kwargs: Dict[str, Any] = {}
|
||||||
|
scheduler_step_signature = inspect.signature(scheduler.step)
|
||||||
|
if "generator" in scheduler_step_signature.parameters:
|
||||||
|
# At some point, someone decided that schedulers that accept a generator should use the original seed with
|
||||||
|
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
|
||||||
|
# reproducibility.
|
||||||
|
#
|
||||||
|
# These Invoke-supported schedulers accept a generator as of 2024-06-04:
|
||||||
|
# - DDIMScheduler
|
||||||
|
# - DDPMScheduler
|
||||||
|
# - DPMSolverMultistepScheduler
|
||||||
|
# - EulerAncestralDiscreteScheduler
|
||||||
|
# - EulerDiscreteScheduler
|
||||||
|
# - KDPM2AncestralDiscreteScheduler
|
||||||
|
# - LCMScheduler
|
||||||
|
# - TCDScheduler
|
||||||
|
scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
|
||||||
|
if isinstance(scheduler, TCDScheduler):
|
||||||
|
scheduler_step_kwargs.update({"eta": 1.0})
|
||||||
|
|
||||||
|
return timesteps, init_timestep, scheduler_step_kwargs
|
||||||
|
|
||||||
|
    def prep_inpaint_mask(
        self, context: InvocationContext, latents: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
        if self.denoise_mask is None:
            return None, None, False

        mask = context.tensors.load(self.denoise_mask.mask_name)
        mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
        if self.denoise_mask.masked_latents_name is not None:
            masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
        else:
            masked_latents = torch.where(mask < 0.5, 0.0, latents)

        return 1 - mask, masked_latents, self.denoise_mask.gradient

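The returned mask is an inverted copy (1 - mask) of the tensor that was loaded, and masked_latents zeroes the latents wherever the loaded mask falls below 0.5. A small sketch of just those two tensor operations, with made-up values:

    import torch

    # Hypothetical 1x1x2x2 mask and latents, values chosen only for illustration.
    mask = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
    latents = torch.ones(1, 1, 2, 2)

    masked_latents = torch.where(mask < 0.5, 0.0, latents)  # zero latents where the loaded mask is < 0.5
    inverted_mask = 1 - mask                                 # inverted copy returned to the caller
    print(masked_latents, inverted_mask)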
    @staticmethod
    def prepare_noise_and_latents(
        context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
    ) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
        """Depending on the workflow, we expect different combinations of noise and latents to be provided. This
        function handles preparing these values accordingly.

        Expected workflows:
        - Text-to-Image Denoising: `noise` is provided, `latents` is not. `latents` is initialized to zeros.
        - Image-to-Image Denoising: `noise` and `latents` are both provided.
        - Text-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
        - Image-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.

        NOTE(ryand): I wrote this docstring, but I am not the original author of this code. There may be other workflows
        I haven't considered.
        """
        noise = None
        if noise_field is not None:
            noise = context.tensors.load(noise_field.latents_name)

        if latents_field is not None:
            latents = context.tensors.load(latents_field.latents_name)
        elif noise is not None:
            latents = torch.zeros_like(noise)
        else:
            raise ValueError("'latents' or 'noise' must be provided!")

        if noise is not None and noise.shape[1:] != latents.shape[1:]:
            raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")

        # The seed comes from (in order of priority): the noise field, the latents field, or 0.
        seed = 0
        if noise_field is not None and noise_field.seed is not None:
            seed = noise_field.seed
        elif latents_field is not None and latents_field.seed is not None:
            seed = latents_field.seed
        else:
            seed = 0

        return seed, noise, latents

@torch.no_grad()
|
||||||
|
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
|
|
||||||
|
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||||
|
|
||||||
|
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||||
|
# below. Investigate whether this is appropriate.
|
||||||
|
t2i_adapter_data = self.run_t2i_adapters(
|
||||||
|
context,
|
||||||
|
self.t2i_adapter,
|
||||||
|
latents.shape,
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
ip_adapters: List[IPAdapterField] = []
|
||||||
|
if self.ip_adapter is not None:
|
||||||
|
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
|
||||||
|
if isinstance(self.ip_adapter, list):
|
||||||
|
ip_adapters = self.ip_adapter
|
||||||
|
else:
|
||||||
|
ip_adapters = [self.ip_adapter]
|
||||||
|
|
||||||
|
# If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
|
||||||
|
# a series of image conditioning embeddings. This is being done here rather than in the
|
||||||
|
# big model context below in order to use less VRAM on low-VRAM systems.
|
||||||
|
# The image prompts are then passed to prep_ip_adapter_data().
|
||||||
|
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
|
||||||
|
|
||||||
|
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||||
|
unet_config = context.models.get_config(self.unet.unet.key)
|
||||||
|
|
||||||
|
def step_callback(state: PipelineIntermediateState) -> None:
|
||||||
|
context.util.sd_step_callback(state, unet_config.base)
|
||||||
|
|
||||||
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
|
for lora in self.unet.loras:
|
||||||
|
lora_info = context.models.load(lora.lora)
|
||||||
|
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||||
|
yield (lora_info.model, lora.weight)
|
||||||
|
del lora_info
|
||||||
|
return
|
||||||
|
|
||||||
|
unet_info = context.models.load(self.unet.unet)
|
||||||
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
|
with (
|
||||||
|
ExitStack() as exit_stack,
|
||||||
|
unet_info.model_on_device() as (model_state_dict, unet),
|
||||||
|
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||||
|
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||||
|
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||||
|
ModelPatcher.apply_lora_unet(
|
||||||
|
unet,
|
||||||
|
loras=_lora_loader(),
|
||||||
|
model_state_dict=model_state_dict,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
assert isinstance(unet, UNet2DConditionModel)
|
||||||
|
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
if noise is not None:
|
||||||
|
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
if masked_latents is not None:
|
||||||
|
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
|
scheduler = get_scheduler(
|
||||||
|
context=context,
|
||||||
|
scheduler_info=self.unet.scheduler,
|
||||||
|
scheduler_name=self.scheduler,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = self.create_pipeline(unet, scheduler)
|
||||||
|
|
||||||
|
_, _, latent_height, latent_width = latents.shape
|
||||||
|
conditioning_data = self.get_conditioning_data(
|
||||||
|
context=context,
|
||||||
|
positive_conditioning_field=self.positive_conditioning,
|
||||||
|
negative_conditioning_field=self.negative_conditioning,
|
||||||
|
unet=unet,
|
||||||
|
latent_height=latent_height,
|
||||||
|
latent_width=latent_width,
|
||||||
|
cfg_scale=self.cfg_scale,
|
||||||
|
steps=self.steps,
|
||||||
|
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
|
)
|
||||||
|
|
||||||
|
controlnet_data = self.prep_control_data(
|
||||||
|
context=context,
|
||||||
|
control_input=self.control,
|
||||||
|
latents_shape=latents.shape,
|
||||||
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
exit_stack=exit_stack,
|
||||||
|
)
|
||||||
|
|
||||||
|
ip_adapter_data = self.prep_ip_adapter_data(
|
||||||
|
context=context,
|
||||||
|
ip_adapters=ip_adapters,
|
||||||
|
image_prompts=image_prompts,
|
||||||
|
exit_stack=exit_stack,
|
||||||
|
latent_height=latent_height,
|
||||||
|
latent_width=latent_width,
|
||||||
|
dtype=unet.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||||
|
scheduler,
|
||||||
|
device=unet.device,
|
||||||
|
steps=self.steps,
|
||||||
|
denoising_start=self.denoising_start,
|
||||||
|
denoising_end=self.denoising_end,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
result_latents = pipeline.latents_from_embeddings(
|
||||||
|
latents=latents,
|
||||||
|
timesteps=timesteps,
|
||||||
|
init_timestep=init_timestep,
|
||||||
|
noise=noise,
|
||||||
|
seed=seed,
|
||||||
|
mask=mask,
|
||||||
|
masked_latents=masked_latents,
|
||||||
|
is_gradient_mask=gradient_mask,
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
control_data=controlnet_data,
|
||||||
|
ip_adapter_data=ip_adapter_data,
|
||||||
|
t2i_adapter_data=t2i_adapter_data,
|
||||||
|
callback=step_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
result_latents = result_latents.to("cpu")
|
||||||
|
TorchDevice.empty_cache()
|
||||||
|
|
||||||
|
name = context.tensors.save(tensor=result_latents)
|
||||||
|
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
||||||
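Several of the helpers called from invoke() above (prep_control_data, prep_ip_adapter_data) receive an exit_stack argument so that every ControlNet or IP-Adapter model they open stays loaded until the whole with-block in invoke() unwinds. The snippet below is a generic contextlib.ExitStack sketch of that pattern, not InvokeAI code; the load() context manager and the model names are made up for illustration.

    from contextlib import ExitStack, contextmanager

    @contextmanager
    def load(name):  # stand-in for loading a model onto the GPU
        print(f"load {name}")
        try:
            yield name
        finally:
            print(f"unload {name}")

    def prep_adapters(names, exit_stack):
        # Each model is entered on the caller's ExitStack, so it remains
        # available for the rest of the denoise loop.
        return [exit_stack.enter_context(load(n)) for n in names]

    with ExitStack() as exit_stack:
        adapters = prep_adapters(["controlnet-a", "ip-adapter-b"], exit_stack)
        print("denoise with", adapters)
    # Both models are released here, when the ExitStack unwinds.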
invokeai/app/invocations/ideal_size.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import math
from typing import Tuple

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
from invokeai.app.invocations.model import UNetField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType


@invocation_output("ideal_size_output")
class IdealSizeOutput(BaseInvocationOutput):
    """Base class for invocations that output an image"""

    width: int = OutputField(description="The ideal width of the image (in pixels)")
    height: int = OutputField(description="The ideal height of the image (in pixels)")


@invocation(
    "ideal_size",
    title="Ideal Size",
    tags=["latents", "math", "ideal_size"],
    version="1.0.3",
)
class IdealSizeInvocation(BaseInvocation):
    """Calculates the ideal size for generation to avoid duplication"""

    width: int = InputField(default=1024, description="Final image width")
    height: int = InputField(default=576, description="Final image height")
    unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
    multiplier: float = InputField(
        default=1.0,
        description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "
        "initial generation artifacts if too large)",
    )

    def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
        return tuple((x - x % multiple_of) for x in args)

    def invoke(self, context: InvocationContext) -> IdealSizeOutput:
        unet_config = context.models.get_config(self.unet.unet.key)
        aspect = self.width / self.height
        dimension: float = 512
        if unet_config.base == BaseModelType.StableDiffusion2:
            dimension = 768
        elif unet_config.base == BaseModelType.StableDiffusionXL:
            dimension = 1024
        dimension = dimension * self.multiplier
        min_dimension = math.floor(dimension * 0.5)
        model_area = dimension * dimension  # hardcoded for now since all models are trained on square images

        if aspect > 1.0:
            init_height = max(min_dimension, math.sqrt(model_area / aspect))
            init_width = init_height * aspect
        else:
            init_width = max(min_dimension, math.sqrt(model_area * aspect))
            init_height = init_width / aspect

        scaled_width, scaled_height = self.trim_to_multiple_of(
            math.floor(init_width),
            math.floor(init_height),
        )

        return IdealSizeOutput(width=scaled_width, height=scaled_height)
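A worked example of the Ideal Size math: assuming an SDXL base model, a 1536x640 target, and multiplier 1.0 (all hypothetical inputs), the node keeps the model's native 1024x1024 training area while matching the requested aspect ratio, then snaps both sides down to a multiple of LATENT_SCALE_FACTOR (8). This sketch inlines trim_to_multiple_of for readability.

    import math

    width, height, multiplier = 1536, 640, 1.0   # hypothetical request
    dimension = 1024 * multiplier                # SDXL base dimension
    aspect = width / height                      # 2.4
    min_dimension = math.floor(dimension * 0.5)  # 512
    model_area = dimension * dimension           # 1048576

    init_height = max(min_dimension, math.sqrt(model_area / aspect))  # ~661.0
    init_width = init_height * aspect                                 # ~1586.4

    def trim(x):
        return math.floor(x) - math.floor(x) % 8

    print(trim(init_width), trim(init_height))   # 1584 656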
invokeai/app/invocations/image_to_latents.py (new file, 125 lines)
@@ -0,0 +1,125 @@
|
|||||||
|
from functools import singledispatchmethod
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
from diffusers.models.attention_processor import (
|
||||||
|
AttnProcessor2_0,
|
||||||
|
LoRAAttnProcessor2_0,
|
||||||
|
LoRAXFormersAttnProcessor,
|
||||||
|
XFormersAttnProcessor,
|
||||||
|
)
|
||||||
|
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||||
|
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
|
from invokeai.app.invocations.constants import DEFAULT_PRECISION
|
||||||
|
from invokeai.app.invocations.fields import (
|
||||||
|
FieldDescriptions,
|
||||||
|
ImageField,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
)
|
||||||
|
from invokeai.app.invocations.model import VAEField
|
||||||
|
from invokeai.app.invocations.primitives import LatentsOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.model_manager import LoadedModel
|
||||||
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"i2l",
|
||||||
|
title="Image to Latents",
|
||||||
|
tags=["latents", "image", "vae", "i2l"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.2",
|
||||||
|
)
|
||||||
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
|
"""Encodes an image into latents."""
|
||||||
|
|
||||||
|
image: ImageField = InputField(
|
||||||
|
description="The image to encode",
|
||||||
|
)
|
||||||
|
vae: VAEField = InputField(
|
||||||
|
description=FieldDescriptions.vae,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||||
|
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
with vae_info as vae:
|
||||||
|
assert isinstance(vae, torch.nn.Module)
|
||||||
|
orig_dtype = vae.dtype
|
||||||
|
if upcast:
|
||||||
|
vae.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||||
|
vae.decoder.mid_block.attentions[0].processor,
|
||||||
|
(
|
||||||
|
AttnProcessor2_0,
|
||||||
|
XFormersAttnProcessor,
|
||||||
|
LoRAXFormersAttnProcessor,
|
||||||
|
LoRAAttnProcessor2_0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# if xformers or torch_2_0 is used attention block does not need
|
||||||
|
# to be in float32 which can save lots of memory
|
||||||
|
if use_torch_2_0_or_xformers:
|
||||||
|
vae.post_quant_conv.to(orig_dtype)
|
||||||
|
vae.decoder.conv_in.to(orig_dtype)
|
||||||
|
vae.decoder.mid_block.to(orig_dtype)
|
||||||
|
# else:
|
||||||
|
# latents = latents.float()
|
||||||
|
|
||||||
|
else:
|
||||||
|
vae.to(dtype=torch.float16)
|
||||||
|
# latents = latents.half()
|
||||||
|
|
||||||
|
if tiled:
|
||||||
|
vae.enable_tiling()
|
||||||
|
else:
|
||||||
|
vae.disable_tiling()
|
||||||
|
|
||||||
|
# non_noised_latents_from_image
|
||||||
|
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||||
|
with torch.inference_mode():
|
||||||
|
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
|
||||||
|
|
||||||
|
latents = vae.config.scaling_factor * latents
|
||||||
|
latents = latents.to(dtype=orig_dtype)
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
|
vae_info = context.models.load(self.vae.vae)
|
||||||
|
|
||||||
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
|
if image_tensor.dim() == 3:
|
||||||
|
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||||
|
|
||||||
|
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
||||||
|
|
||||||
|
latents = latents.to("cpu")
|
||||||
|
name = context.tensors.save(tensor=latents)
|
||||||
|
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||||
|
|
||||||
|
@singledispatchmethod
|
||||||
|
@staticmethod
|
||||||
|
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
assert isinstance(vae, torch.nn.Module)
|
||||||
|
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||||
|
latents: torch.Tensor = image_tensor_dist.sample().to(
|
||||||
|
dtype=vae.dtype
|
||||||
|
) # FIXME: uses torch.randn. make reproducible!
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@_encode_to_tensor.register
|
||||||
|
@staticmethod
|
||||||
|
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
assert isinstance(vae, torch.nn.Module)
|
||||||
|
latents: torch.FloatTensor = vae.encode(image_tensor).latents
|
||||||
|
return latents
|
||||||
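_encode_to_tensor() above relies on functools single-dispatch to pick the encode path per VAE class: the generic implementation samples from an AutoencoderKL latent distribution, while the AutoencoderTiny overload reads .latents directly. The snippet below is a generic, self-contained illustration of that dispatch mechanism using functools.singledispatch and two stand-in classes; it is not the InvokeAI implementation.

    from functools import singledispatch

    class FullVAE: ...   # stand-in for AutoencoderKL
    class TinyVAE: ...   # stand-in for AutoencoderTiny

    @singledispatch
    def encode(vae, pixels):
        raise NotImplementedError(type(vae))

    @encode.register
    def _(vae: FullVAE, pixels):
        return f"sample from latent_dist of {pixels}"

    @encode.register
    def _(vae: TinyVAE, pixels):
        return f"read .latents for {pixels}"

    print(encode(FullVAE(), "img"))  # dispatches on the first argument's type
    print(encode(TinyVAE(), "img"))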
@@ -42,15 +42,16 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
         """Infill the image with the specified method"""
         pass

-    def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
+    def load_image(self) -> tuple[Image.Image, bool]:
         """Process the image to have an alpha channel before being infilled"""
-        image = context.images.get_pil(self.image.image_name)
+        image = self._context.images.get_pil(self.image.image_name)
         has_alpha = True if image.mode == "RGBA" else False
         return image, has_alpha

     def invoke(self, context: InvocationContext) -> ImageOutput:
+        self._context = context
         # Retrieve and process image to be infilled
-        input_image, has_alpha = self.load_image(context)
+        input_image, has_alpha = self.load_image()

         # If the input image has no alpha channel, return it
         if has_alpha is False:
@@ -133,8 +134,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""

     def infill(self, image: Image.Image):
-        lama = LaMA()
-        return lama(image)
+        with self._context.models.load_remote_model(
+            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+            loader=LaMA.load_jit_model,
+        ) as model:
+            lama = LaMA(model)
+            return lama(image)


 @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
(File diff suppressed because it is too large)

invokeai/app/invocations/latents_to_image.py (new file, 107 lines)
@@ -0,0 +1,107 @@
|
|||||||
|
import torch
|
||||||
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
|
from diffusers.models.attention_processor import (
|
||||||
|
AttnProcessor2_0,
|
||||||
|
LoRAAttnProcessor2_0,
|
||||||
|
LoRAXFormersAttnProcessor,
|
||||||
|
XFormersAttnProcessor,
|
||||||
|
)
|
||||||
|
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||||
|
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||||
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
|
from invokeai.app.invocations.constants import DEFAULT_PRECISION
|
||||||
|
from invokeai.app.invocations.fields import (
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
LatentsField,
|
||||||
|
WithBoard,
|
||||||
|
WithMetadata,
|
||||||
|
)
|
||||||
|
from invokeai.app.invocations.model import VAEField
|
||||||
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.stable_diffusion import set_seamless
|
||||||
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"l2i",
|
||||||
|
title="Latents to Image",
|
||||||
|
tags=["latents", "image", "vae", "l2i"],
|
||||||
|
category="latents",
|
||||||
|
version="1.2.2",
|
||||||
|
)
|
||||||
|
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
|
latents: LatentsField = InputField(
|
||||||
|
description=FieldDescriptions.latents,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
vae: VAEField = InputField(
|
||||||
|
description=FieldDescriptions.vae,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||||
|
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
|
vae_info = context.models.load(self.vae.vae)
|
||||||
|
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
|
||||||
|
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||||
|
assert isinstance(vae, torch.nn.Module)
|
||||||
|
latents = latents.to(vae.device)
|
||||||
|
if self.fp32:
|
||||||
|
vae.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||||
|
vae.decoder.mid_block.attentions[0].processor,
|
||||||
|
(
|
||||||
|
AttnProcessor2_0,
|
||||||
|
XFormersAttnProcessor,
|
||||||
|
LoRAXFormersAttnProcessor,
|
||||||
|
LoRAAttnProcessor2_0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# if xformers or torch_2_0 is used attention block does not need
|
||||||
|
# to be in float32 which can save lots of memory
|
||||||
|
if use_torch_2_0_or_xformers:
|
||||||
|
vae.post_quant_conv.to(latents.dtype)
|
||||||
|
vae.decoder.conv_in.to(latents.dtype)
|
||||||
|
vae.decoder.mid_block.to(latents.dtype)
|
||||||
|
else:
|
||||||
|
latents = latents.float()
|
||||||
|
|
||||||
|
else:
|
||||||
|
vae.to(dtype=torch.float16)
|
||||||
|
latents = latents.half()
|
||||||
|
|
||||||
|
if self.tiled or context.config.get().force_tiled_decode:
|
||||||
|
vae.enable_tiling()
|
||||||
|
else:
|
||||||
|
vae.disable_tiling()
|
||||||
|
|
||||||
|
# clear memory as vae decode can request a lot
|
||||||
|
TorchDevice.empty_cache()
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
# copied from diffusers pipeline
|
||||||
|
latents = latents / vae.config.scaling_factor
|
||||||
|
image = vae.decode(latents, return_dict=False)[0]
|
||||||
|
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
|
||||||
|
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||||
|
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||||
|
|
||||||
|
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
|
||||||
|
|
||||||
|
TorchDevice.empty_cache()
|
||||||
|
|
||||||
|
image_dto = context.images.save(image=image)
|
||||||
|
|
||||||
|
return ImageOutput.build(image_dto)
|
||||||
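The decode path above follows the usual diffusers convention: divide the latents by the VAE scaling factor, decode, then map the [-1, 1] output to [0, 1] before converting to a PIL image. A minimal sketch of just the denormalization step, with a random tensor standing in for the real VAE output:

    import torch

    decoded = torch.rand(1, 3, 64, 64) * 2 - 1                  # pretend VAE output in [-1, 1]
    image = (decoded / 2 + 0.5).clamp(0, 1)                     # denormalize to [0, 1]
    np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()  # NCHW -> NHWC for PIL conversion
    print(np_image.shape, np_image.min() >= 0.0, np_image.max() <= 1.0)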
invokeai/app/invocations/resize_latents.py (new file, 103 lines)
@@ -0,0 +1,103 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||||
|
from invokeai.app.invocations.fields import (
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
LatentsField,
|
||||||
|
)
|
||||||
|
from invokeai.app.invocations.primitives import LatentsOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"lresize",
|
||||||
|
title="Resize Latents",
|
||||||
|
tags=["latents", "resize"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.2",
|
||||||
|
)
|
||||||
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
|
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||||
|
|
||||||
|
latents: LatentsField = InputField(
|
||||||
|
description=FieldDescriptions.latents,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
width: int = InputField(
|
||||||
|
ge=64,
|
||||||
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
|
description=FieldDescriptions.width,
|
||||||
|
)
|
||||||
|
height: int = InputField(
|
||||||
|
ge=64,
|
||||||
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
|
description=FieldDescriptions.height,
|
||||||
|
)
|
||||||
|
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||||
|
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
device = TorchDevice.choose_torch_device()
|
||||||
|
|
||||||
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
|
latents.to(device),
|
||||||
|
size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
|
||||||
|
mode=self.mode,
|
||||||
|
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
resized_latents = resized_latents.to("cpu")
|
||||||
|
|
||||||
|
TorchDevice.empty_cache()
|
||||||
|
|
||||||
|
name = context.tensors.save(tensor=resized_latents)
|
||||||
|
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"lscale",
|
||||||
|
title="Scale Latents",
|
||||||
|
tags=["latents", "resize"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.2",
|
||||||
|
)
|
||||||
|
class ScaleLatentsInvocation(BaseInvocation):
|
||||||
|
"""Scales latents by a given factor."""
|
||||||
|
|
||||||
|
latents: LatentsField = InputField(
|
||||||
|
description=FieldDescriptions.latents,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
|
||||||
|
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||||
|
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
|
device = TorchDevice.choose_torch_device()
|
||||||
|
|
||||||
|
# resizing
|
||||||
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
|
latents.to(device),
|
||||||
|
scale_factor=self.scale_factor,
|
||||||
|
mode=self.mode,
|
||||||
|
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
resized_latents = resized_latents.to("cpu")
|
||||||
|
TorchDevice.empty_cache()
|
||||||
|
|
||||||
|
name = context.tensors.save(tensor=resized_latents)
|
||||||
|
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
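Both resize nodes above wrap torch.nn.functional.interpolate and operate in latent space, so the pixel dimensions a user supplies are divided by LATENT_SCALE_FACTOR (8) before resizing. A small hedged sketch of that conversion with a dummy latent tensor (the 1x4x64x64 shape and 1024x768 target are arbitrary):

    import torch
    import torch.nn.functional as F

    LATENT_SCALE_FACTOR = 8
    latents = torch.randn(1, 4, 64, 64)   # e.g. a 512x512 image in latent space
    target_w, target_h = 1024, 768        # requested pixel dimensions

    resized = F.interpolate(
        latents,
        size=(target_h // LATENT_SCALE_FACTOR, target_w // LATENT_SCALE_FACTOR),
        mode="bilinear",
        antialias=False,
    )
    print(resized.shape)  # torch.Size([1, 4, 96, 128])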
invokeai/app/invocations/scheduler.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    InputField,
    OutputField,
    UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput):
    scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)


@invocation(
    "scheduler",
    title="Scheduler",
    tags=["scheduler"],
    category="latents",
    version="1.0.0",
)
class SchedulerInvocation(BaseInvocation):
    """Selects a scheduler."""

    scheduler: SCHEDULER_NAME_VALUES = InputField(
        default="euler",
        description=FieldDescriptions.scheduler,
        ui_type=UIType.Scheduler,
    )

    def invoke(self, context: InvocationContext) -> SchedulerOutput:
        return SchedulerOutput(scheduler=self.scheduler)
@@ -0,0 +1,281 @@
|
|||||||
|
import copy
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from typing import Iterator, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||||
|
from pydantic import field_validator
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||||
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||||
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
|
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||||
|
from invokeai.app.invocations.fields import (
|
||||||
|
ConditioningField,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
LatentsField,
|
||||||
|
UIType,
|
||||||
|
)
|
||||||
|
from invokeai.app.invocations.model import UNetField
|
||||||
|
from invokeai.app.invocations.primitives import LatentsOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
|
||||||
|
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
|
||||||
|
MultiDiffusionPipeline,
|
||||||
|
MultiDiffusionRegionConditioning,
|
||||||
|
)
|
||||||
|
from invokeai.backend.tiles.tiles import (
|
||||||
|
calc_tiles_min_overlap,
|
||||||
|
)
|
||||||
|
from invokeai.backend.tiles.utils import TBLR
|
||||||
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
|
||||||
|
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
|
||||||
|
"""Crop a ControlNetData object to a region."""
|
||||||
|
# Create a shallow copy of the control_data object.
|
||||||
|
control_data_copy = copy.copy(control_data)
|
||||||
|
# The ControlNet reference image is the only attribute that needs to be cropped.
|
||||||
|
control_data_copy.image_tensor = control_data.image_tensor[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
|
||||||
|
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
|
||||||
|
]
|
||||||
|
return control_data_copy
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"tiled_multi_diffusion_denoise_latents",
|
||||||
|
title="Tiled Multi-Diffusion Denoise Latents",
|
||||||
|
tags=["upscale", "denoise"],
|
||||||
|
category="latents",
|
||||||
|
classification=Classification.Beta,
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
|
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||||
|
"""Tiled Multi-Diffusion denoising.
|
||||||
|
|
||||||
|
This node handles automatically tiling the input image, and is primarily intended for global refinement of images
|
||||||
|
in tiled upscaling workflows. Future Multi-Diffusion nodes should allow the user to specify custom regions with
|
||||||
|
different parameters for each region to harness the full power of Multi-Diffusion.
|
||||||
|
|
||||||
|
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
|
||||||
|
T2I-Adapter, masking, etc.).
|
||||||
|
"""
|
||||||
|
|
||||||
|
positive_conditioning: ConditioningField = InputField(
|
||||||
|
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||||
|
)
|
||||||
|
negative_conditioning: ConditioningField = InputField(
|
||||||
|
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||||
|
)
|
||||||
|
noise: LatentsField | None = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.noise,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
latents: LatentsField | None = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.latents,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
tile_height: int = InputField(
|
||||||
|
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space."
|
||||||
|
)
|
||||||
|
tile_width: int = InputField(
|
||||||
|
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space."
|
||||||
|
)
|
||||||
|
tile_overlap: int = InputField(
|
||||||
|
default=32,
|
||||||
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
|
gt=0,
|
||||||
|
description="The overlap between adjacent tiles in pixel space. (Of course, tile merging is applied in latent "
|
||||||
|
"space.) Tiles will be cropped during merging (if necessary) to ensure that they overlap by exactly this "
|
||||||
|
"amount.",
|
||||||
|
)
|
||||||
|
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
|
||||||
|
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||||
|
denoising_start: float = InputField(
|
||||||
|
default=0.0,
|
||||||
|
ge=0,
|
||||||
|
le=1,
|
||||||
|
description=FieldDescriptions.denoising_start,
|
||||||
|
)
|
||||||
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||||
|
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||||
|
default="euler",
|
||||||
|
description=FieldDescriptions.scheduler,
|
||||||
|
ui_type=UIType.Scheduler,
|
||||||
|
)
|
||||||
|
unet: UNetField = InputField(
|
||||||
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="UNet",
|
||||||
|
)
|
||||||
|
cfg_rescale_multiplier: float = InputField(
|
||||||
|
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||||
|
)
|
||||||
|
control: ControlField | list[ControlField] | None = InputField(
|
||||||
|
default=None,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("cfg_scale")
|
||||||
|
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
||||||
|
"""Validate that all cfg_scale values are >= 1"""
|
||||||
|
if isinstance(v, list):
|
||||||
|
for i in v:
|
||||||
|
if i < 1:
|
||||||
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
|
else:
|
||||||
|
if v < 1:
|
||||||
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_pipeline(
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
scheduler: SchedulerMixin,
|
||||||
|
) -> MultiDiffusionPipeline:
|
||||||
|
# TODO(ryand): Get rid of this FakeVae hack.
|
||||||
|
class FakeVae:
|
||||||
|
class FakeVaeConfig:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.block_out_channels = [0]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.config = FakeVae.FakeVaeConfig()
|
||||||
|
|
||||||
|
return MultiDiffusionPipeline(
|
||||||
|
vae=FakeVae(),
|
||||||
|
text_encoder=None,
|
||||||
|
tokenizer=None,
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
# Convert tile image-space dimensions to latent-space dimensions.
|
||||||
|
latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
|
||||||
|
latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
|
||||||
|
latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
|
||||||
|
|
||||||
|
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
|
_, _, latent_height, latent_width = latents.shape
|
||||||
|
|
||||||
|
# Calculate the tile locations to cover the latent-space image.
|
||||||
|
tiles = calc_tiles_min_overlap(
|
||||||
|
image_height=latent_height,
|
||||||
|
image_width=latent_width,
|
||||||
|
tile_height=latent_tile_height,
|
||||||
|
tile_width=latent_tile_width,
|
||||||
|
min_overlap=latent_tile_overlap,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the unet's config so that we can pass the base to sd_step_callback().
|
||||||
|
unet_config = context.models.get_config(self.unet.unet.key)
|
||||||
|
|
||||||
|
def step_callback(state: PipelineIntermediateState) -> None:
|
||||||
|
context.util.sd_step_callback(state, unet_config.base)
|
||||||
|
|
||||||
|
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
||||||
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
|
for lora in self.unet.loras:
|
||||||
|
lora_info = context.models.load(lora.lora)
|
||||||
|
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||||
|
yield (lora_info.model, lora.weight)
|
||||||
|
del lora_info
|
||||||
|
|
||||||
|
# Load the UNet model.
|
||||||
|
unet_info = context.models.load(self.unet.unet)
|
||||||
|
|
||||||
|
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||||
|
assert isinstance(unet, UNet2DConditionModel)
|
||||||
|
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
if noise is not None:
|
||||||
|
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
scheduler = get_scheduler(
|
||||||
|
context=context,
|
||||||
|
scheduler_info=self.unet.scheduler,
|
||||||
|
scheduler_name=self.scheduler,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
|
||||||
|
|
||||||
|
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
||||||
|
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
||||||
|
context=context,
|
||||||
|
positive_conditioning_field=self.positive_conditioning,
|
||||||
|
negative_conditioning_field=self.negative_conditioning,
|
||||||
|
unet=unet,
|
||||||
|
latent_height=latent_tile_height,
|
||||||
|
latent_width=latent_tile_width,
|
||||||
|
cfg_scale=self.cfg_scale,
|
||||||
|
steps=self.steps,
|
||||||
|
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
|
)
|
||||||
|
|
||||||
|
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
|
||||||
|
context=context,
|
||||||
|
control_input=self.control,
|
||||||
|
latents_shape=list(latents.shape),
|
||||||
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
exit_stack=exit_stack,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split the controlnet_data into tiles.
|
||||||
|
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
|
||||||
|
controlnet_data_tiles: list[list[ControlNetData]] = []
|
||||||
|
for tile in tiles:
|
||||||
|
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
|
||||||
|
controlnet_data_tiles.append(tile_controlnet_data)
|
||||||
|
|
||||||
|
# Prepare the MultiDiffusionRegionConditioning list.
|
||||||
|
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
|
||||||
|
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
|
||||||
|
multi_diffusion_conditioning.append(
|
||||||
|
MultiDiffusionRegionConditioning(
|
||||||
|
region=tile,
|
||||||
|
text_conditioning_data=conditioning_data,
|
||||||
|
control_data=tile_controlnet_data,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
|
||||||
|
scheduler,
|
||||||
|
device=unet.device,
|
||||||
|
steps=self.steps,
|
||||||
|
denoising_start=self.denoising_start,
|
||||||
|
denoising_end=self.denoising_end,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run Multi-Diffusion denoising.
|
||||||
|
result_latents = pipeline.multi_diffusion_denoise(
|
||||||
|
multi_diffusion_conditioning=multi_diffusion_conditioning,
|
||||||
|
target_overlap=latent_tile_overlap,
|
||||||
|
latents=latents,
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
|
noise=noise,
|
||||||
|
timesteps=timesteps,
|
||||||
|
init_timestep=init_timestep,
|
||||||
|
callback=step_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
result_latents = result_latents.to("cpu")
|
||||||
|
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
|
||||||
|
TorchDevice.empty_cache()
|
||||||
|
|
||||||
|
name = context.tensors.save(tensor=result_latents)
|
||||||
|
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
||||||
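The tiled node above converts tile_width, tile_height, and tile_overlap from pixel space to latent space (divide by 8) and then asks calc_tiles_min_overlap() for a layout that covers the latent image. The repo's tiling utility is not reproduced here; the snippet below is only a rough, self-contained approximation of a minimum-overlap layout in one dimension, included to make the parameters concrete.

    import math

    def tile_starts(image_size: int, tile_size: int, min_overlap: int) -> list[int]:
        """Approximate 1-D tile placement with at least `min_overlap` latent pixels of overlap."""
        if tile_size >= image_size:
            return [0]
        stride = tile_size - min_overlap
        num_tiles = math.ceil((image_size - tile_size) / stride) + 1
        # Spread tiles evenly so the last tile ends exactly at the image edge.
        return [round(i * (image_size - tile_size) / (num_tiles - 1)) for i in range(num_tiles)]

    # A 1536x1536 image is a 192x192 latent; 1024-px tiles with 32-px overlap
    # become 128-latent tiles with at least 4 latents of overlap.
    print(tile_starts(192, 128, 4))  # [0, 64] -> two tiles covering rows 0-127 and 64-191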
@@ -1,5 +1,4 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@@ -10,10 +9,8 @@ from pydantic import ConfigDict
|
|||||||
from invokeai.app.invocations.fields import ImageField
|
from invokeai.app.invocations.fields import ImageField
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
|
||||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, invocation
|
from .baseinvocation import BaseInvocation, invocation
|
||||||
from .fields import InputField, WithBoard, WithMetadata
|
from .fields import InputField, WithBoard, WithMetadata
|
||||||
@@ -52,7 +49,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
rrdbnet_model = None
|
rrdbnet_model = None
|
||||||
netscale = None
|
netscale = None
|
||||||
esrgan_model_path = None
|
|
||||||
|
|
||||||
if self.model_name in [
|
if self.model_name in [
|
||||||
"RealESRGAN_x4plus.pth",
|
"RealESRGAN_x4plus.pth",
|
||||||
@@ -95,28 +91,25 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
context.logger.error(msg)
|
context.logger.error(msg)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
|
loadnet = context.models.load_remote_model(
|
||||||
|
source=ESRGAN_MODEL_URLS[self.model_name],
|
||||||
# Downloads the ESRGAN model if it doesn't already exist
|
|
||||||
download_with_progress_bar(
|
|
||||||
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
|
|
||||||
)
|
)
|
||||||
|
|
||||||
upscaler = RealESRGAN(
|
with loadnet as loadnet_model:
|
||||||
scale=netscale,
|
upscaler = RealESRGAN(
|
||||||
model_path=esrgan_model_path,
|
scale=netscale,
|
||||||
model=rrdbnet_model,
|
loadnet=loadnet_model,
|
||||||
half=False,
|
model=rrdbnet_model,
|
||||||
tile=self.tile_size,
|
half=False,
|
||||||
)
|
tile=self.tile_size,
|
||||||
|
)
|
||||||
|
|
||||||
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||||
# TODO: This strips the alpha... is that okay?
|
# TODO: This strips the alpha... is that okay?
|
||||||
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||||
upscaled_image = upscaler.upscale(cv2_image)
|
upscaled_image = upscaler.upscale(cv2_image)
|
||||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
|
||||||
|
|
||||||
TorchDevice.empty_cache()
|
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||||
|
|
||||||
image_dto = context.images.save(image=pil_image)
|
image_dto = context.images.save(image=pil_image)
|
||||||
|
|
||||||
|
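As the comment in the ESRGAN node notes, Real-ESRGAN operates on cv2 (BGR) arrays while PIL uses RGB, so the image is converted on the way in and back again on the way out. A tiny hedged round-trip sketch of that channel-order conversion (the 8x8 random image is a placeholder):

    import cv2
    import numpy as np
    from PIL import Image

    pil_image = Image.fromarray(np.random.randint(0, 255, (8, 8, 3), dtype=np.uint8), mode="RGB")

    cv2_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)    # PIL RGB -> cv2 BGR
    # ... an upscaler would run on cv2_image here ...
    back = Image.fromarray(cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB))  # cv2 BGR -> PIL RGB

    assert np.array_equal(np.array(pil_image), np.array(back))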
|||||||
@@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
patchmatch: Enable patchmatch inpaint code.
|
patchmatch: Enable patchmatch inpaint code.
|
||||||
models_dir: Path to the models directory.
|
models_dir: Path to the models directory.
|
||||||
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
|
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
|
||||||
|
download_cache_dir: Path to the directory that contains dynamically downloaded models.
|
||||||
legacy_conf_dir: Path to directory of legacy checkpoint config files.
|
legacy_conf_dir: Path to directory of legacy checkpoint config files.
|
||||||
db_dir: Path to InvokeAI databases directory.
|
db_dir: Path to InvokeAI databases directory.
|
||||||
outputs_dir: Path to directory for outputs.
|
outputs_dir: Path to directory for outputs.
|
||||||
@@ -112,6 +113,7 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
|
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
|
||||||
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
|
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
|
||||||
max_queue_size: Maximum number of items in the session queue.
|
max_queue_size: Maximum number of items in the session queue.
|
||||||
|
clear_queue_on_startup: Empties session queue on startup.
|
||||||
allow_nodes: List of nodes to allow. Omit to allow all.
|
allow_nodes: List of nodes to allow. Omit to allow all.
|
||||||
deny_nodes: List of nodes to deny. Omit to deny none.
|
deny_nodes: List of nodes to deny. Omit to deny none.
|
||||||
node_cache_size: How many cached nodes to keep in memory.
|
node_cache_size: How many cached nodes to keep in memory.
|
||||||
@@ -146,7 +148,8 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
|
|
||||||
# PATHS
|
# PATHS
|
||||||
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
|
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
|
||||||
convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
|
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
|
||||||
|
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
|
||||||
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
|
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
|
||||||
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
|
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
|
||||||
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
||||||
@@ -184,6 +187,7 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
|
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
|
||||||
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
|
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
|
||||||
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
|
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
|
||||||
|
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
|
||||||
|
|
||||||
# NODES
|
# NODES
|
||||||
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
|
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
|
||||||
@@ -303,6 +307,11 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
||||||
return self._resolve(self.convert_cache_dir)
|
return self._resolve(self.convert_cache_dir)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def download_cache_path(self) -> Path:
|
||||||
|
"""Path to the downloaded models directory, resolved to an absolute path.."""
|
||||||
|
return self._resolve(self.download_cache_dir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def custom_nodes_path(self) -> Path:
|
def custom_nodes_path(self) -> Path:
|
||||||
"""Path to the custom nodes directory, resolved to an absolute path.."""
|
"""Path to the custom nodes directory, resolved to an absolute path.."""
|
||||||
|
|||||||
@@ -1,10 +1,17 @@
|
|||||||
"""Init file for download queue."""
|
"""Init file for download queue."""
|
||||||
|
|
||||||
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
|
from .download_base import (
|
||||||
|
DownloadJob,
|
||||||
|
DownloadJobStatus,
|
||||||
|
DownloadQueueServiceBase,
|
||||||
|
MultiFileDownloadJob,
|
||||||
|
UnknownJobIDException,
|
||||||
|
)
|
||||||
from .download_default import DownloadQueueService, TqdmProgress
|
from .download_default import DownloadQueueService, TqdmProgress
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DownloadJob",
|
"DownloadJob",
|
||||||
|
"MultiFileDownloadJob",
|
||||||
"DownloadQueueServiceBase",
|
"DownloadQueueServiceBase",
|
||||||
"DownloadQueueService",
|
"DownloadQueueService",
|
||||||
"TqdmProgress",
|
"TqdmProgress",
|
||||||
|
|||||||
--- a/invokeai/app/services/download/download_base.py
+++ b/invokeai/app/services/download/download_base.py
@@ -5,11 +5,13 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from functools import total_ordering
 from pathlib import Path
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Set, Union

 from pydantic import BaseModel, Field, PrivateAttr
 from pydantic.networks import AnyHttpUrl

+from invokeai.backend.model_manager.metadata import RemoteModelFile
+

 class DownloadJobStatus(str, Enum):
     """State of a download job."""
@@ -33,30 +35,23 @@ class ServiceInactiveException(Exception):
     """This exception is raised when user attempts to initiate a download before the service is started."""


-DownloadEventHandler = Callable[["DownloadJob"], None]
-DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
+SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
+SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
+MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
+MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
+DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
+DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]


-@total_ordering
-class DownloadJob(BaseModel):
-    """Class to monitor and control a model download request."""
+class DownloadJobBase(BaseModel):
+    """Base of classes to monitor and control downloads."""

-    # required variables to be passed in on creation
-    source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
-    dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
-    access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
     # automatically assigned on creation
     id: int = Field(description="Numeric ID of this job", default=-1)  # default id is a sentinel
-    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")

-    # set internally during download process
+    dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
+    download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
     status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
-    download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
-    job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
-    job_ended: Optional[str] = Field(
-        default=None, description="Timestamp for when the download job ende1d (completed or errored)"
-    )
-    content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
     bytes: int = Field(default=0, description="Bytes downloaded so far")
     total_bytes: int = Field(default=0, description="Total file size (bytes)")

@@ -74,14 +69,6 @@ class DownloadJob(BaseModel):
     _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
     _on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)

-    def __hash__(self) -> int:
-        """Return hash of the string representation of this object, for indexing."""
-        return hash(str(self))
-
-    def __le__(self, other: "DownloadJob") -> bool:
-        """Return True if this job's priority is less than another's."""
-        return self.priority <= other.priority
-
     def cancel(self) -> None:
         """Call to cancel the job."""
         self._cancelled = True
@@ -98,6 +85,11 @@ class DownloadJob(BaseModel):
         """Return true if job completed without errors."""
         return self.status == DownloadJobStatus.COMPLETED

+    @property
+    def waiting(self) -> bool:
+        """Return true if the job is waiting to run."""
+        return self.status == DownloadJobStatus.WAITING
+
     @property
     def running(self) -> bool:
         """Return true if the job is running."""
@@ -154,6 +146,37 @@ class DownloadJob(BaseModel):
         self._on_cancelled = on_cancelled


+@total_ordering
+class DownloadJob(DownloadJobBase):
+    """Class to monitor and control a model download request."""
+
+    # required variables to be passed in on creation
+    source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
+    access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
+    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
+
+    # set internally during download process
+    job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
+    job_ended: Optional[str] = Field(
+        default=None, description="Timestamp for when the download job ende1d (completed or errored)"
+    )
+    content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
+
+    def __hash__(self) -> int:
+        """Return hash of the string representation of this object, for indexing."""
+        return hash(str(self))
+
+    def __le__(self, other: "DownloadJob") -> bool:
+        """Return True if this job's priority is less than another's."""
+        return self.priority <= other.priority
+
+
+class MultiFileDownloadJob(DownloadJobBase):
+    """Class to monitor and control multifile downloads."""
+
+    download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
+
+
 class DownloadQueueServiceBase(ABC):
     """Multithreaded queue for downloading models via URL."""

@@ -201,6 +224,48 @@ class DownloadQueueServiceBase(ABC):
         """
         pass

+    @abstractmethod
+    def multifile_download(
+        self,
+        parts: List[RemoteModelFile],
+        dest: Path,
+        access_token: Optional[str] = None,
+        submit_job: bool = True,
+        on_start: Optional[DownloadEventHandler] = None,
+        on_progress: Optional[DownloadEventHandler] = None,
+        on_complete: Optional[DownloadEventHandler] = None,
+        on_cancelled: Optional[DownloadEventHandler] = None,
+        on_error: Optional[DownloadExceptionHandler] = None,
+    ) -> MultiFileDownloadJob:
+        """
+        Create and enqueue a multifile download job.
+
+        :param parts: Set of URL / filename pairs
+        :param dest: Path to download to. See below.
+        :param access_token: Access token to download the indicated files. If not provided,
+        each file's URL may be matched to an access token using the config file matching
+        system.
+        :param submit_job: If true [default] then submit the job for execution. Otherwise,
+        you will need to pass the job to submit_multifile_download().
+        :param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
+        events.
+        :returns: A MultiFileDownloadJob object for monitoring the state of the download.
+
+        The `dest` argument is a Path object pointing to a directory. All downloads
+        with be placed inside this directory. The callbacks will receive the
+        MultiFileDownloadJob.
+        """
+        pass
+
+    @abstractmethod
+    def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
+        """
+        Enqueue a previously-created multi-file download job.
+
+        :param job: A MultiFileDownloadJob created with multifile_download()
+        """
+        pass
+
     @abstractmethod
     def submit_download_job(
         self,
@@ -252,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
         pass

     @abstractmethod
-    def cancel_job(self, job: DownloadJob) -> None:
+    def cancel_job(self, job: DownloadJobBase) -> None:
         """Cancel the job, clearing partial downloads and putting it into ERROR state."""
         pass

@@ -262,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
         pass

     @abstractmethod
-    def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
+    def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
         """Wait until the indicated download job has reached a terminal state.

         This will block until the indicated install job has completed,
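Reviewer note: a minimal usage sketch of the multifile API declared above. The queue construction, URLs, and file names are illustrative assumptions, not taken from the diff; the callback receives the MultiFileDownloadJob as documented in the docstring.

from pathlib import Path

from invokeai.app.services.download import DownloadQueueService, MultiFileDownloadJob
from invokeai.backend.model_manager.metadata import RemoteModelFile

queue = DownloadQueueService()
queue.start()

# Two hypothetical files belonging to one model; paths are relative to `dest`.
parts = [
    RemoteModelFile(url="https://example.com/repo/config.json", path=Path("config.json"), size=0),
    RemoteModelFile(url="https://example.com/repo/weights.safetensors", path=Path("weights.safetensors"), size=0),
]

def on_complete(job: MultiFileDownloadJob) -> None:
    # download_path points at the top-level directory created under dest
    print(f"done: {job.download_path}")

job = queue.multifile_download(parts=parts, dest=Path("/tmp/model_dl"), on_complete=on_complete)
queue.wait_for_job(job)
queue.stop()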
--- a/invokeai/app/services/download/download_default.py
+++ b/invokeai/app/services/download/download_default.py
@@ -8,30 +8,32 @@ import time
 import traceback
 from pathlib import Path
 from queue import Empty, PriorityQueue
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
+from typing import Any, Dict, List, Literal, Optional, Set

 import requests
 from pydantic.networks import AnyHttpUrl
 from requests import HTTPError
 from tqdm import tqdm

+from invokeai.app.services.config import InvokeAIAppConfig, get_config
+from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.util.misc import get_iso_timestamp
+from invokeai.backend.model_manager.metadata import RemoteModelFile
 from invokeai.backend.util.logging import InvokeAILogger

 from .download_base import (
     DownloadEventHandler,
     DownloadExceptionHandler,
     DownloadJob,
+    DownloadJobBase,
     DownloadJobCancelledException,
     DownloadJobStatus,
     DownloadQueueServiceBase,
+    MultiFileDownloadJob,
     ServiceInactiveException,
     UnknownJobIDException,
 )

-if TYPE_CHECKING:
-    from invokeai.app.services.events.events_base import EventServiceBase
-
 # Maximum number of bytes to download during each call to requests.iter_content()
 DOWNLOAD_CHUNK_SIZE = 100000

@@ -42,20 +44,24 @@ class DownloadQueueService(DownloadQueueServiceBase):
     def __init__(
         self,
         max_parallel_dl: int = 5,
+        app_config: Optional[InvokeAIAppConfig] = None,
         event_bus: Optional["EventServiceBase"] = None,
         requests_session: Optional[requests.sessions.Session] = None,
     ):
         """
         Initialize DownloadQueue.

+        :param app_config: InvokeAIAppConfig object
         :param max_parallel_dl: Number of simultaneous downloads allowed [5].
         :param requests_session: Optional requests.sessions.Session object, for unit tests.
         """
+        self._app_config = app_config or get_config()
         self._jobs: Dict[int, DownloadJob] = {}
+        self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
         self._next_job_id = 0
         self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
         self._stop_event = threading.Event()
-        self._job_completed_event = threading.Event()
+        self._job_terminated_event = threading.Event()
         self._worker_pool: Set[threading.Thread] = set()
         self._lock = threading.Lock()
         self._logger = InvokeAILogger.get_logger("DownloadQueueService")
@@ -107,18 +113,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
             raise ServiceInactiveException(
                 "The download service is not currently accepting requests. Please call start() to initialize the service."
             )
-        with self._lock:
-            job.id = self._next_job_id
-            self._next_job_id += 1
-            job.set_callbacks(
-                on_start=on_start,
-                on_progress=on_progress,
-                on_complete=on_complete,
-                on_cancelled=on_cancelled,
-                on_error=on_error,
-            )
-            self._jobs[job.id] = job
-            self._queue.put(job)
+        job.id = self._next_id()
+        job.set_callbacks(
+            on_start=on_start,
+            on_progress=on_progress,
+            on_complete=on_complete,
+            on_cancelled=on_cancelled,
+            on_error=on_error,
+        )
+        self._jobs[job.id] = job
+        self._queue.put(job)

     def download(
         self,
@@ -141,7 +145,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
             source=source,
             dest=dest,
             priority=priority,
-            access_token=access_token,
+            access_token=access_token or self._lookup_access_token(source),
         )
         self.submit_download_job(
             job,
@@ -153,10 +157,63 @@ class DownloadQueueService(DownloadQueueServiceBase):
         )
         return job

+    def multifile_download(
+        self,
+        parts: List[RemoteModelFile],
+        dest: Path,
+        access_token: Optional[str] = None,
+        submit_job: bool = True,
+        on_start: Optional[DownloadEventHandler] = None,
+        on_progress: Optional[DownloadEventHandler] = None,
+        on_complete: Optional[DownloadEventHandler] = None,
+        on_cancelled: Optional[DownloadEventHandler] = None,
+        on_error: Optional[DownloadExceptionHandler] = None,
+    ) -> MultiFileDownloadJob:
+        mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
+        mfdj.set_callbacks(
+            on_start=on_start,
+            on_progress=on_progress,
+            on_complete=on_complete,
+            on_cancelled=on_cancelled,
+            on_error=on_error,
+        )
+
+        for part in parts:
+            url = part.url
+            path = dest / part.path
+            assert path.is_relative_to(dest), "only relative download paths accepted"
+            job = DownloadJob(
+                source=url,
+                dest=path,
+                access_token=access_token,
+            )
+            mfdj.download_parts.add(job)
+            self._download_part2parent[job.source] = mfdj
+
+        if submit_job:
+            self.submit_multifile_download(mfdj)
+        return mfdj
+
+    def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
+        for download_job in job.download_parts:
+            self.submit_download_job(
+                download_job,
+                on_start=self._mfd_started,
+                on_progress=self._mfd_progress,
+                on_complete=self._mfd_complete,
+                on_cancelled=self._mfd_cancelled,
+                on_error=self._mfd_error,
+            )
+
     def join(self) -> None:
         """Wait for all jobs to complete."""
         self._queue.join()

+    def _next_id(self) -> int:
+        with self._lock:
+            id = self._next_job_id
+            self._next_job_id += 1
+        return id
+
     def list_jobs(self) -> List[DownloadJob]:
         """List all the jobs."""
         return list(self._jobs.values())
@@ -178,14 +235,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
         except KeyError as excp:
             raise UnknownJobIDException("Unrecognized job") from excp

-    def cancel_job(self, job: DownloadJob) -> None:
+    def cancel_job(self, job: DownloadJobBase) -> None:
         """
         Cancel the indicated job.

         If it is running it will be stopped.
         job.status will be set to DownloadJobStatus.CANCELLED
         """
-        with self._lock:
+        if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
             job.cancel()

     def cancel_all_jobs(self) -> None:
@@ -194,12 +251,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
             if not job.in_terminal_state:
                 self.cancel_job(job)

-    def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
+    def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
         """Block until the indicated job has reached terminal state, or when timeout limit reached."""
         start = time.time()
         while not job.in_terminal_state:
-            if self._job_completed_event.wait(timeout=0.25):  # in case we miss an event
-                self._job_completed_event.clear()
+            if self._job_terminated_event.wait(timeout=0.25):  # in case we miss an event
+                self._job_terminated_event.clear()
             if timeout > 0 and time.time() - start > timeout:
                 raise TimeoutError("Timeout exceeded")
         return job
@@ -228,22 +285,25 @@ class DownloadQueueService(DownloadQueueServiceBase):
                 job.job_started = get_iso_timestamp()
                 self._do_download(job)
                 self._signal_job_complete(job)
-            except (OSError, HTTPError) as excp:
-                job.error_type = excp.__class__.__name__ + f"({str(excp)})"
-                job.error = traceback.format_exc()
-                self._signal_job_error(job, excp)
             except DownloadJobCancelledException:
                 self._signal_job_cancelled(job)
                 self._cleanup_cancelled_job(job)
+            except Exception as excp:
+                job.error_type = excp.__class__.__name__ + f"({str(excp)})"
+                job.error = traceback.format_exc()
+                self._signal_job_error(job, excp)
             finally:
                 job.job_ended = get_iso_timestamp()
-                self._job_completed_event.set()  # signal a change to terminal state
+                self._job_terminated_event.set()  # signal a change to terminal state
+                self._download_part2parent.pop(job.source, None)  # if this is a subpart of a multipart job, remove it
+                self._job_terminated_event.set()
                 self._queue.task_done()

         self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")

     def _do_download(self, job: DownloadJob) -> None:
         """Do the actual download."""

         url = job.source
         header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
         open_mode = "wb"
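Reviewer note: callers see the renamed termination event only indirectly, through wait_for_job. A short sketch under the assumption that `queue` is a started DownloadQueueService (as in the earlier sketch); the URL and destination are illustrative.

from pathlib import Path

from pydantic.networks import AnyHttpUrl

job = queue.download(source=AnyHttpUrl("https://example.com/file.bin"), dest=Path("/tmp"))
try:
    queue.wait_for_job(job, timeout=300)  # raises TimeoutError("Timeout exceeded") on expiry
except TimeoutError:
    queue.cancel_job(job)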
@@ -335,38 +395,29 @@ class DownloadQueueService(DownloadQueueServiceBase):
     def _in_progress_path(self, path: Path) -> Path:
         return path.with_name(path.name + ".downloading")

+    def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
+        # Pull the token from config if it exists and matches the URL
+        token = None
+        for pair in self._app_config.remote_api_tokens or []:
+            if re.search(pair.url_regex, str(source)):
+                token = pair.token
+                break
+        return token
+
     def _signal_job_started(self, job: DownloadJob) -> None:
         job.status = DownloadJobStatus.RUNNING
-        if job.on_start:
-            try:
-                job.on_start(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_start")
         if self._event_bus:
             self._event_bus.emit_download_started(job)

     def _signal_job_progress(self, job: DownloadJob) -> None:
-        if job.on_progress:
-            try:
-                job.on_progress(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_progress")
         if self._event_bus:
             self._event_bus.emit_download_progress(job)

     def _signal_job_complete(self, job: DownloadJob) -> None:
         job.status = DownloadJobStatus.COMPLETED
-        if job.on_complete:
-            try:
-                job.on_complete(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_complete")
         if self._event_bus:
             self._event_bus.emit_download_complete(job)

@@ -374,26 +425,21 @@ class DownloadQueueService(DownloadQueueServiceBase):
         if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
             return
         job.status = DownloadJobStatus.CANCELLED
-        if job.on_cancelled:
-            try:
-                job.on_cancelled(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_cancelled")
         if self._event_bus:
             self._event_bus.emit_download_cancelled(job)

+        # if multifile download, then signal the parent
+        if parent_job := self._download_part2parent.get(job.source, None):
+            if not parent_job.in_terminal_state:
+                parent_job.status = DownloadJobStatus.CANCELLED
+                self._execute_cb(parent_job, "on_cancelled")
+
     def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
         job.status = DownloadJobStatus.ERROR
         self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
-        if job.on_error:
-            try:
-                job.on_error(job, excp)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_error", excp)
         if self._event_bus:
             self._event_bus.emit_download_error(job)

@@ -406,6 +452,97 @@ class DownloadQueueService(DownloadQueueServiceBase):
         except OSError as excp:
             self._logger.warning(excp)

+    ########################################
+    # callbacks used for multifile downloads
+    ########################################
+    def _mfd_started(self, download_job: DownloadJob) -> None:
+        self._logger.info(f"File download started: {download_job.source}")
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            if mf_job.waiting:
+                mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
+                mf_job.status = DownloadJobStatus.RUNNING
+                assert download_job.download_path is not None
+                path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
+                mf_job.download_path = (
+                    mf_job.dest / path_relative_to_destdir.parts[0]
+                )  # keep just the first component of the path
+                self._execute_cb(mf_job, "on_start")
+
+    def _mfd_progress(self, download_job: DownloadJob) -> None:
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            if mf_job.cancelled:
+                for part in mf_job.download_parts:
+                    self.cancel_job(part)
+            elif mf_job.running:
+                mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
+                mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts)
+                self._execute_cb(mf_job, "on_progress")
+
+    def _mfd_complete(self, download_job: DownloadJob) -> None:
+        self._logger.info(f"Download complete: {download_job.source}")
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+
+            # are there any more active jobs left in this task?
+            if mf_job.running and all(x.complete for x in mf_job.download_parts):
+                mf_job.status = DownloadJobStatus.COMPLETED
+                self._execute_cb(mf_job, "on_complete")
+
+            # we're done with this sub-job
+            self._job_terminated_event.set()
+
+    def _mfd_cancelled(self, download_job: DownloadJob) -> None:
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            assert mf_job is not None
+
+            if not mf_job.in_terminal_state:
+                self._logger.warning(f"Download cancelled: {download_job.source}")
+                mf_job.cancel()
+
+            for s in mf_job.download_parts:
+                self.cancel_job(s)
+
+    def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            assert mf_job is not None
+            if not mf_job.in_terminal_state:
+                mf_job.status = download_job.status
+                mf_job.error = download_job.error
+                mf_job.error_type = download_job.error_type
+                self._execute_cb(mf_job, "on_error", excp)
+                self._logger.error(
+                    f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
+                )
+                for s in [x for x in mf_job.download_parts if x.running]:
+                    self.cancel_job(s)
+                self._download_part2parent.pop(download_job.source)
+                self._job_terminated_event.set()
+
+    def _execute_cb(
+        self,
+        job: DownloadJob | MultiFileDownloadJob,
+        callback_name: Literal[
+            "on_start",
+            "on_progress",
+            "on_complete",
+            "on_cancelled",
+            "on_error",
+        ],
+        excp: Optional[Exception] = None,
+    ) -> None:
+        if callback := getattr(job, callback_name, None):
+            args = [job, excp] if excp else [job]
+            try:
+                callback(*args)
+            except Exception as e:
+                self._logger.error(
+                    f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
+                )
+
+
 def get_pc_name_max(directory: str) -> int:
     if hasattr(os, "pathconf"):
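Reviewer note: a sketch of the configuration-driven matching that _lookup_access_token relies on. It assumes `remote_api_tokens` entries in the app config carry the `url_regex` and `token` fields used above; the URL is illustrative.

import re

from invokeai.app.services.config import get_config

config = get_config()
source = "https://huggingface.co/org/repo/resolve/main/model.safetensors"  # hypothetical URL

token = None
for pair in config.remote_api_tokens or []:
    if re.search(pair.url_regex, source):
        token = pair.token
        break
# `token` becomes the job's access_token when the caller did not supply one.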
--- a/invokeai/app/services/events/events_base.py
+++ b/invokeai/app/services/events/events_base.py
@@ -22,6 +22,7 @@ from invokeai.app.services.events.events_common import (
     ModelInstallCompleteEvent,
     ModelInstallDownloadProgressEvent,
     ModelInstallDownloadsCompleteEvent,
+    ModelInstallDownloadStartedEvent,
     ModelInstallErrorEvent,
     ModelInstallStartedEvent,
     ModelLoadCompleteEvent,
@@ -34,7 +35,6 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 if TYPE_CHECKING:
     from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
     from invokeai.app.services.download.download_base import DownloadJob
-    from invokeai.app.services.events.events_common import EventBase
     from invokeai.app.services.model_install.model_install_common import ModelInstallJob
     from invokeai.app.services.session_processor.session_processor_common import ProgressImage
     from invokeai.app.services.session_queue.session_queue_common import (
@@ -145,6 +145,10 @@ class EventServiceBase:

     # region Model install

+    def emit_model_install_download_started(self, job: "ModelInstallJob") -> None:
+        """Emitted at intervals while the install job is started (remote models only)."""
+        self.dispatch(ModelInstallDownloadStartedEvent.build(job))
+
     def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
         """Emitted at intervals while the install job is in progress (remote models only)."""
         self.dispatch(ModelInstallDownloadProgressEvent.build(job))
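Reviewer note: the expected call site for the new emitter, as a one-line sketch; `event_bus` and `install_job` are illustrative handles, not taken from this diff.

# Fired once per remote install as soon as the download parts are known (assumption).
event_bus.emit_model_install_download_started(install_job)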
--- a/invokeai/app/services/events/events_common.py
+++ b/invokeai/app/services/events/events_common.py
@@ -3,9 +3,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, P

 from fastapi_events.handlers.local import local_handler
 from fastapi_events.registry.payload_schema import registry as payload_schema
-from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator
+from pydantic import BaseModel, ConfigDict, Field

-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
 from invokeai.app.services.session_processor.session_processor_common import ProgressImage
 from invokeai.app.services.session_queue.session_queue_common import (
     QUEUE_ITEM_STATUS,
@@ -14,6 +13,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     SessionQueueItem,
     SessionQueueStatus,
 )
+from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
 from invokeai.app.util.misc import get_timestamp
 from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -98,17 +98,9 @@ class InvocationEventBase(QueueItemEventBase):
     item_id: int = Field(description="The ID of the queue item")
     batch_id: str = Field(description="The ID of the queue batch")
     session_id: str = Field(description="The ID of the session (aka graph execution state)")
-    invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
+    invocation: AnyInvocation = Field(description="The ID of the invocation")
     invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")

-    @field_validator("invocation", mode="plain")
-    @classmethod
-    def validate_invocation(cls, v: Any):
-        """Validates the invocation using the dynamic type adapter."""
-
-        invocation = BaseInvocation.get_typeadapter().validate_python(v)
-        return invocation
-

 @payload_schema.register
 class InvocationStartedEvent(InvocationEventBase):
@@ -117,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase):
     __event_name__ = "invocation_started"

     @classmethod
-    def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
+    def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
         return cls(
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
@@ -144,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
     def build(
         cls,
         queue_item: SessionQueueItem,
-        invocation: BaseInvocation,
+        invocation: AnyInvocation,
         intermediate_state: PipelineIntermediateState,
         progress_image: ProgressImage,
     ) -> "InvocationDenoiseProgressEvent":
@@ -182,19 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase):

     __event_name__ = "invocation_complete"

-    result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
+    result: AnyInvocationOutput = Field(description="The result of the invocation")

-    @field_validator("result", mode="plain")
-    @classmethod
-    def validate_results(cls, v: Any):
-        """Validates the invocation result using the dynamic type adapter."""
-
-        result = BaseInvocationOutput.get_typeadapter().validate_python(v)
-        return result
-
     @classmethod
     def build(
-        cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
+        cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
     ) -> "InvocationCompleteEvent":
         return cls(
             queue_id=queue_item.queue_id,
@@ -223,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase):
     def build(
         cls,
         queue_item: SessionQueueItem,
-        invocation: BaseInvocation,
+        invocation: AnyInvocation,
         error_type: str,
         error_message: str,
         error_traceback: str,
@@ -433,6 +417,42 @@ class ModelLoadCompleteEvent(ModelEventBase):
         return cls(config=config, submodel_type=submodel_type)


+@payload_schema.register
+class ModelInstallDownloadStartedEvent(ModelEventBase):
+    """Event model for model_install_download_started"""
+
+    __event_name__ = "model_install_download_started"
+
+    id: int = Field(description="The ID of the install job")
+    source: str = Field(description="Source of the model; local path, repo_id or url")
+    local_path: str = Field(description="Where model is downloading to")
+    bytes: int = Field(description="Number of bytes downloaded so far")
+    total_bytes: int = Field(description="Total size of download, including all files")
+    parts: list[dict[str, int | str]] = Field(
+        description="Progress of downloading URLs that comprise the model, if any"
+    )
+
+    @classmethod
+    def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
+        parts: list[dict[str, str | int]] = [
+            {
+                "url": str(x.source),
+                "local_path": str(x.download_path),
+                "bytes": x.bytes,
+                "total_bytes": x.total_bytes,
+            }
+            for x in job.download_parts
+        ]
+        return cls(
+            id=job.id,
+            source=str(job.source),
+            local_path=job.local_path.as_posix(),
+            parts=parts,
+            bytes=job.bytes,
+            total_bytes=job.total_bytes,
+        )
+
+
 @payload_schema.register
 class ModelInstallDownloadProgressEvent(ModelEventBase):
     """Event model for model_install_download_progress"""
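Reviewer note: a hedged sketch of the payload shape the new model_install_download_started event carries, using the field names defined above; every value below is invented.

payload = {
    "id": 7,
    "source": "author/some-model",
    "local_path": "/invokeai/models/.download_cache/author_some-model",
    "bytes": 0,
    "total_bytes": 5_368_709_120,
    "parts": [
        {
            "url": "https://example.com/some-model/weights.safetensors",
            "local_path": "/invokeai/models/.download_cache/author_some-model/weights.safetensors",
            "bytes": 0,
            "total_bytes": 5_368_709_120,
        }
    ],
}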
@@ -13,7 +13,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
from invokeai.backend.model_manager import AnyModelConfig
|
||||||
|
|
||||||
|
|
||||||
class ModelInstallServiceBase(ABC):
|
class ModelInstallServiceBase(ABC):
|
||||||
@@ -243,12 +243,11 @@ class ModelInstallServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
|
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
|
||||||
"""
|
"""
|
||||||
Download the model file located at source to the models cache and return its Path.
|
Download the model file located at source to the models cache and return its Path.
|
||||||
|
|
||||||
:param source: A Url or a string that can be converted into one.
|
:param source: A string representing a URL or repo_id.
|
||||||
:param access_token: Optional access token to access restricted resources.
|
|
||||||
|
|
||||||
The model file will be downloaded into the system-wide model cache
|
The model file will be downloaded into the system-wide model cache
|
||||||
(`models/.cache`) if it isn't already there. Note that the model cache
|
(`models/.cache`) if it isn't already there. Note that the model cache
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
|||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from invokeai.app.services.download import DownloadJob
|
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
|
||||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||||
from invokeai.backend.model_manager.config import ModelSourceType
|
from invokeai.backend.model_manager.config import ModelSourceType
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||||
@@ -26,13 +26,6 @@ class InstallStatus(str, Enum):
|
|||||||
CANCELLED = "cancelled" # terminated with an error message
|
CANCELLED = "cancelled" # terminated with an error message
|
||||||
|
|
||||||
|
|
||||||
class ModelInstallPart(BaseModel):
|
|
||||||
url: AnyHttpUrl
|
|
||||||
path: Path
|
|
||||||
bytes: int = 0
|
|
||||||
total_bytes: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class UnknownInstallJobException(Exception):
|
class UnknownInstallJobException(Exception):
|
||||||
"""Raised when the status of an unknown job is requested."""
|
"""Raised when the status of an unknown job is requested."""
|
||||||
|
|
||||||
@@ -169,6 +162,7 @@ class ModelInstallJob(BaseModel):
|
|||||||
)
|
)
|
||||||
# internal flags and transitory settings
|
# internal flags and transitory settings
|
||||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||||
|
_multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
|
||||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||||
|
|
||||||
def set_error(self, e: Exception) -> None:
|
def set_error(self, e: Exception) -> None:
|
||||||
|
|||||||
@@ -5,21 +5,22 @@ import os
|
|||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from hashlib import sha256
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
from shutil import copyfile, copytree, move, rmtree
|
from shutil import copyfile, copytree, move, rmtree
|
||||||
from tempfile import mkdtemp
|
from tempfile import mkdtemp
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from huggingface_hub import HfFolder
|
from huggingface_hub import HfFolder
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
from pydantic_core import Url
|
||||||
from requests import Session
|
from requests import Session
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
|
||||||
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
|
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
|
||||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||||
@@ -44,6 +45,7 @@ from invokeai.backend.model_manager.search import ModelSearch
|
|||||||
from invokeai.backend.util import InvokeAILogger
|
from invokeai.backend.util import InvokeAILogger
|
||||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
from invokeai.backend.util.util import slugify
|
||||||
|
|
||||||
from .model_install_common import (
|
from .model_install_common import (
|
||||||
MODEL_SOURCE_TO_TYPE_MAP,
|
MODEL_SOURCE_TO_TYPE_MAP,
|
||||||
@@ -58,9 +60,6 @@ from .model_install_common import (
|
|||||||
|
|
||||||
TMPDIR_PREFIX = "tmpinstall_"
|
TMPDIR_PREFIX = "tmpinstall_"
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
|
||||||
|
|
||||||
|
|
||||||
class ModelInstallService(ModelInstallServiceBase):
|
class ModelInstallService(ModelInstallServiceBase):
|
||||||
"""class for InvokeAI model installation."""
|
"""class for InvokeAI model installation."""
|
||||||
@@ -91,7 +90,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._downloads_changed_event = threading.Event()
|
self._downloads_changed_event = threading.Event()
|
||||||
self._install_completed_event = threading.Event()
|
self._install_completed_event = threading.Event()
|
||||||
self._download_queue = download_queue
|
self._download_queue = download_queue
|
||||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
self._download_cache: Dict[int, ModelInstallJob] = {}
|
||||||
self._running = False
|
self._running = False
|
||||||
self._session = session
|
self._session = session
|
||||||
self._install_thread: Optional[threading.Thread] = None
|
self._install_thread: Optional[threading.Thread] = None
|
||||||
@@ -210,33 +209,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
inplace: Optional[bool] = False,
|
inplace: Optional[bool] = False,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
"""Install a model using pattern matching to infer the type of source."""
|
||||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
source_obj = self._guess_source(source)
|
||||||
source_obj: Optional[StringLikeSource] = None
|
if isinstance(source_obj, LocalModelSource):
|
||||||
|
source_obj.inplace = inplace
|
||||||
if Path(source).exists(): # A local file or directory
|
elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
|
||||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
source_obj.access_token = access_token
|
||||||
elif match := re.match(hf_repoid_re, source):
|
|
||||||
source_obj = HFModelSource(
|
|
||||||
repo_id=match.group(1),
|
|
||||||
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
|
|
||||||
subfolder=Path(match.group(3)) if match.group(3) else None,
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
|
||||||
elif re.match(r"^https?://[^/]+", source):
|
|
||||||
# Pull the token from config if it exists and matches the URL
|
|
||||||
_token = access_token
|
|
||||||
if _token is None:
|
|
||||||
for pair in self.app_config.remote_api_tokens or []:
|
|
||||||
if re.search(pair.url_regex, source):
|
|
||||||
_token = pair.token
|
|
||||||
break
|
|
||||||
source_obj = URLModelSource(
|
|
||||||
url=AnyHttpUrl(source),
|
|
||||||
access_token=_token,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported model source: '{source}'")
|
|
||||||
return self.import_model(source_obj, config)
|
return self.import_model(source_obj, config)
|
||||||
|
|
||||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||||
@@ -297,8 +275,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||||
"""Cancel the indicated job."""
|
"""Cancel the indicated job."""
|
||||||
job.cancel()
|
job.cancel()
|
||||||
with self._lock:
|
self._logger.warning(f"Cancelling {job.source}")
|
||||||
self._cancel_download_parts(job)
|
if dj := job._multifile_job:
|
||||||
|
self._download_queue.cancel_job(dj)
|
||||||
|
|
||||||
def prune_jobs(self) -> None:
|
def prune_jobs(self) -> None:
|
||||||
"""Prune all completed and errored jobs."""
|
"""Prune all completed and errored jobs."""
|
||||||
@@ -346,7 +325,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
legacy_config_path = stanza.get("config")
|
legacy_config_path = stanza.get("config")
|
||||||
if legacy_config_path:
|
if legacy_config_path:
|
||||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||||
legacy_config_path: Path = self._app_config.root_path / legacy_config_path
|
legacy_config_path = self._app_config.root_path / legacy_config_path
|
||||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||||
config["config_path"] = str(legacy_config_path)
|
config["config_path"] = str(legacy_config_path)
|
||||||
@@ -386,38 +365,95 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
rmtree(model_path)
|
rmtree(model_path)
|
||||||
self.unregister(key)
|
self.unregister(key)
|
||||||
|
|
||||||
def download_and_cache(
|
@classmethod
|
||||||
|
def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
|
||||||
|
escaped_source = slugify(str(source))
|
||||||
|
return app_config.download_cache_path / escaped_source
|
||||||
|
|
||||||
|
def download_and_cache_model(
|
||||||
self,
|
self,
|
||||||
source: Union[str, AnyHttpUrl],
|
source: str | AnyHttpUrl,
|
||||||
access_token: Optional[str] = None,
|
|
||||||
timeout: int = 0,
|
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""Download the model file located at source to the models cache and return its Path."""
|
"""Download the model file located at source to the models cache and return its Path."""
|
||||||
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
|
model_path = self._download_cache_path(str(source), self._app_config)
|
||||||
model_path = self._app_config.convert_cache_path / model_hash
|
|
||||||
|
|
||||||
# We expect the cache directory to contain one and only one downloaded file.
|
# We expect the cache directory to contain one and only one downloaded file or directory.
|
||||||
# We don't know the file's name in advance, as it is set by the download
|
# We don't know the file's name in advance, as it is set by the download
|
||||||
# content-disposition header.
|
# content-disposition header.
|
||||||
if model_path.exists():
|
if model_path.exists():
|
||||||
contents = [x for x in model_path.iterdir() if x.is_file()]
|
contents: List[Path] = list(model_path.iterdir())
|
||||||
if len(contents) > 0:
|
if len(contents) > 0:
|
||||||
return contents[0]
|
return contents[0]
|
||||||
|
|
||||||
model_path.mkdir(parents=True, exist_ok=True)
|
model_path.mkdir(parents=True, exist_ok=True)
|
||||||
job = self._download_queue.download(
|
model_source = self._guess_source(str(source))
|
||||||
source=AnyHttpUrl(str(source)),
|
remote_files, _ = self._remote_files_from_source(model_source)
|
||||||
|
job = self._multifile_download(
|
||||||
dest=model_path,
|
dest=model_path,
|
||||||
access_token=access_token,
|
remote_files=remote_files,
|
||||||
on_progress=TqdmProgress().update,
|
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
|
||||||
)
|
)
|
||||||
self._download_queue.wait_for_job(job, timeout)
|
files_string = "file" if len(remote_files) == 1 else "files"
|
||||||
|
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
|
||||||
|
self._download_queue.wait_for_job(job)
|
||||||
if job.complete:
|
if job.complete:
|
||||||
assert job.download_path is not None
|
assert job.download_path is not None
|
||||||
return job.download_path
|
return job.download_path
|
||||||
else:
|
else:
|
||||||
raise Exception(job.error)
|
raise Exception(job.error)
|
||||||
|
|
||||||
|
def _remote_files_from_source(
|
||||||
|
self, source: ModelSource
|
||||||
|
) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
|
||||||
|
metadata = None
|
||||||
|
if isinstance(source, HFModelSource):
|
||||||
|
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
|
||||||
|
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||||
|
return (
|
||||||
|
metadata.download_urls(
|
||||||
|
variant=source.variant or self._guess_variant(),
|
||||||
|
subfolder=source.subfolder,
|
||||||
|
session=self._session,
|
||||||
|
),
|
||||||
|
metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(source, URLModelSource):
|
||||||
|
try:
|
||||||
|
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||||
|
kwargs: dict[str, Any] = {"session": self._session}
|
||||||
|
metadata = fetcher(**kwargs).from_url(source.url)
|
||||||
|
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||||
|
return metadata.download_urls(session=self._session), metadata
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
|
||||||
|
|
||||||
|
raise Exception(f"No files associated with {source}")
|
||||||
|
|
||||||
|
def _guess_source(self, source: str) -> ModelSource:
|
||||||
|
"""Turn a source string into a ModelSource object."""
|
||||||
|
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||||
|
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||||
|
source_obj: Optional[StringLikeSource] = None
|
||||||
|
|
||||||
|
if Path(source).exists(): # A local file or directory
|
||||||
|
source_obj = LocalModelSource(path=Path(source))
|
||||||
|
elif match := re.match(hf_repoid_re, source):
|
||||||
|
source_obj = HFModelSource(
|
||||||
|
repo_id=match.group(1),
|
||||||
|
variant=ModelRepoVariant(match.group(2)) if match.group(2) else None, # pass None rather than ''
|
||||||
|
subfolder=Path(match.group(3)) if match.group(3) else None,
|
||||||
|
)
|
||||||
|
elif re.match(r"^https?://[^/]+", source):
|
||||||
|
source_obj = URLModelSource(
|
||||||
|
url=Url(source),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported model source: '{source}'")
|
||||||
|
return source_obj
|
||||||
|
|
||||||
# --------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------
|
||||||
# Internal functions that manage the installer threads
|
# Internal functions that manage the installer threads
|
||||||
# --------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------
|
||||||
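
The new `_guess_source()` helper above decides, from a bare string, whether the user supplied a local path, a HuggingFace repo_id (optionally with a variant and subfolder), or a plain URL. A standalone sketch of how that repo-id regex behaves; the variant list here is a stand-in for `ModelRepoVariant.__members__.values()` and the sample strings are illustrative only:

```py
import re

variants = "fp16|fp32|onnx"  # stand-in; the real alternatives come from ModelRepoVariant
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"

for candidate in (
    "stabilityai/sdxl-turbo",           # bare repo_id
    "stabilityai/sdxl-turbo:fp16:vae",  # repo_id + variant + subfolder
    "https://example.com/model.safetensors",
):
    match = re.match(hf_repoid_re, candidate)
    if match:
        print(candidate, "-> HFModelSource", match.group(1), match.group(2), match.group(3))
    else:
        print(candidate, "-> URLModelSource (or LocalModelSource if the path exists)")
```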
@@ -478,16 +514,19 @@ class ModelInstallService(ModelInstallServiceBase):
            job.config_out = self.record_store.get_model(key)
            self._signal_job_completed(job)

-    def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
-        if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
-            job.set_error(
+    def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
+        multifile_download_job = install_job._multifile_job
+        if multifile_download_job and any(
+            x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
+        ):
+            install_job.set_error(
                InvalidModelConfigException(
-                    f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
+                    f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
                )
            )
        else:
-            job.set_error(excp)
-        self._signal_job_errored(job)
+            install_job.set_error(excp)
+        self._signal_job_errored(install_job)

    # --------------------------------------------------------------------------------------------
    # Internal functions that manage the models directory
@@ -513,7 +552,6 @@ class ModelInstallService(ModelInstallServiceBase):
        This is typically only used during testing with a new DB or when using the memory DB, because those are the
        only situations in which we may have orphaned models in the models directory.
        """

        installed_model_paths = {
            (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
        }
@@ -525,8 +563,13 @@ class ModelInstallService(ModelInstallServiceBase):
            if resolved_path in installed_model_paths:
                return True
            # Skip core models entirely - these aren't registered with the model manager.
-            if str(resolved_path).startswith(str(self.app_config.models_path / "core")):
-                return False
+            for special_directory in [
+                self.app_config.models_path / "core",
+                self.app_config.convert_cache_dir,
+                self.app_config.download_cache_dir,
+            ]:
+                if resolved_path.is_relative_to(special_directory):
+                    return False
            try:
                model_id = self.register_path(model_path)
                self._logger.info(f"Registered {model_path.name} with id {model_id}")
@@ -641,20 +684,15 @@ class ModelInstallService(ModelInstallServiceBase):
            inplace=source.inplace or False,
        )

-    def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
+    def _import_from_hf(
+        self,
+        source: HFModelSource,
+        config: Optional[Dict[str, Any]] = None,
+    ) -> ModelInstallJob:
        # Add user's cached access token to HuggingFace requests
-        source.access_token = source.access_token or HfFolder.get_token()
-        if not source.access_token:
-            self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
-
-        metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
-        assert isinstance(metadata, ModelMetadataWithFiles)
-        remote_files = metadata.download_urls(
-            variant=source.variant or self._guess_variant(),
-            subfolder=source.subfolder,
-            session=self._session,
-        )
+        if source.access_token is None:
+            source.access_token = HfFolder.get_token()
+        remote_files, metadata = self._remote_files_from_source(source)

        return self._import_remote_model(
            source=source,
            config=config,
@@ -662,22 +700,12 @@ class ModelInstallService(ModelInstallServiceBase):
            metadata=metadata,
        )

-    def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
-        # URLs from HuggingFace will be handled specially
-        metadata = None
-        fetcher = None
-        try:
-            fetcher = self.get_fetcher_from_url(str(source.url))
-        except ValueError:
-            pass
-        kwargs: dict[str, Any] = {"session": self._session}
-        if fetcher is not None:
-            metadata = fetcher(**kwargs).from_url(source.url)
-        self._logger.debug(f"metadata={metadata}")
-        if metadata and isinstance(metadata, ModelMetadataWithFiles):
-            remote_files = metadata.download_urls(session=self._session)
-        else:
-            remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
+    def _import_from_url(
+        self,
+        source: URLModelSource,
+        config: Optional[Dict[str, Any]],
+    ) -> ModelInstallJob:
+        remote_files, metadata = self._remote_files_from_source(source)
        return self._import_remote_model(
            source=source,
            config=config,
@@ -692,12 +720,9 @@ class ModelInstallService(ModelInstallServiceBase):
        metadata: Optional[AnyModelRepoMetadata],
        config: Optional[Dict[str, Any]],
    ) -> ModelInstallJob:
-        # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
-        # Currently the tmpdir isn't automatically removed at exit because it is
-        # being held in a daemon thread.
        if len(remote_files) == 0:
            raise ValueError(f"{source}: No downloadable files found")
-        tmpdir = Path(
+        destdir = Path(
            mkdtemp(
                dir=self._app_config.models_path,
                prefix=TMPDIR_PREFIX,
@@ -708,55 +733,28 @@ class ModelInstallService(ModelInstallServiceBase):
            source=source,
            config_in=config or {},
            source_metadata=metadata,
-            local_path=tmpdir,  # local path may change once the download has started due to content-disposition handling
+            local_path=destdir,  # local path may change once the download has started due to content-disposition handling
            bytes=0,
            total_bytes=0,
        )
-        # In the event that there is a subfolder specified in the source,
-        # we need to remove it from the destination path in order to avoid
-        # creating unwanted subfolders
-        if isinstance(source, HFModelSource) and source.subfolder:
-            root = Path(remote_files[0].path.parts[0])
-            subfolder = root / source.subfolder
-        else:
-            root = Path(".")
-            subfolder = Path(".")
-
-        # we remember the path up to the top of the tmpdir so that it may be
-        # removed safely at the end of the install process.
-        install_job._install_tmpdir = tmpdir
-        assert install_job.total_bytes is not None  # to avoid type checking complaints in the loop below
+        # remember the temporary directory for later removal
+        install_job._install_tmpdir = destdir
+        install_job.total_bytes = sum((x.size or 0) for x in remote_files)

-        files_string = "file" if len(remote_files) == 1 else "file"
-        self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
+        multifile_job = self._multifile_download(
+            remote_files=remote_files,
+            dest=destdir,
+            subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
+            access_token=source.access_token,
+            submit_job=False,  # Important! Don't submit the job until we have set our _download_cache dict
+        )
+        self._download_cache[multifile_job.id] = install_job
+        install_job._multifile_job = multifile_job
+
+        files_string = "file" if len(remote_files) == 1 else "files"
+        self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})")
        self._logger.debug(f"remote_files={remote_files}")
-        for model_file in remote_files:
-            url = model_file.url
-            path = root / model_file.path.relative_to(subfolder)
-            self._logger.debug(f"Downloading {url} => {path}")
-            install_job.total_bytes += model_file.size
-            assert hasattr(source, "access_token")
-            dest = tmpdir / path.parent
-            dest.mkdir(parents=True, exist_ok=True)
-            download_job = DownloadJob(
-                source=url,
-                dest=dest,
-                access_token=source.access_token,
-            )
-            self._download_cache[download_job.source] = install_job  # matches a download job to an install job
-            install_job.download_parts.add(download_job)
-
-        # only start the jobs once install_job.download_parts is fully populated
-        for download_job in install_job.download_parts:
-            self._download_queue.submit_download_job(
-                download_job,
-                on_start=self._download_started_callback,
-                on_progress=self._download_progress_callback,
-                on_complete=self._download_complete_callback,
-                on_error=self._download_error_callback,
-                on_cancelled=self._download_cancelled_callback,
-            )
-
+        self._download_queue.submit_multifile_download(multifile_job)
        return install_job

    def _stat_size(self, path: Path) -> int:
@@ -768,87 +766,104 @@ class ModelInstallService(ModelInstallServiceBase):
            size += sum(self._stat_size(Path(root, x)) for x in files)
        return size

+    def _multifile_download(
+        self,
+        remote_files: List[RemoteModelFile],
+        dest: Path,
+        subfolder: Optional[Path] = None,
+        access_token: Optional[str] = None,
+        submit_job: bool = True,
+    ) -> MultiFileDownloadJob:
+        # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
+        # we are installing the "vae" subfolder, we do not want to create an additional folder level, such
+        # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
+        # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
+        if subfolder:
+            top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
+            path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae/
+            path_to_add = Path(f"{top}_{subfolder}")
+        else:
+            path_to_remove = Path(".")
+            path_to_add = Path(".")
+
+        parts: List[RemoteModelFile] = []
+        for model_file in remote_files:
+            assert model_file.size is not None
+            parts.append(
+                RemoteModelFile(
+                    url=model_file.url,  # if a subfolder, then sdxl-turbo_vae/config.json
+                    path=path_to_add / model_file.path.relative_to(path_to_remove),
+                )
+            )
+
+        return self._download_queue.multifile_download(
+            parts=parts,
+            dest=dest,
+            access_token=access_token,
+            submit_job=submit_job,
+            on_start=self._download_started_callback,
+            on_progress=self._download_progress_callback,
+            on_complete=self._download_complete_callback,
+            on_error=self._download_error_callback,
+            on_cancelled=self._download_cancelled_callback,
+        )

    # ------------------------------------------------------------------
    # Callbacks are executed by the download queue in a separate thread
    # ------------------------------------------------------------------
-    def _download_started_callback(self, download_job: DownloadJob) -> None:
-        self._logger.info(f"Model download started: {download_job.source}")
+    def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
        with self._lock:
-            install_job = self._download_cache[download_job.source]
-            install_job.status = InstallStatus.DOWNLOADING
-
-            assert download_job.download_path
-            if install_job.local_path == install_job._install_tmpdir:
-                partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
-                dest_name = partial_path.parts[0]
-                install_job.local_path = install_job._install_tmpdir / dest_name
-
-            # Update the total bytes count for remote sources.
-            if not install_job.total_bytes:
-                install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
+            if install_job := self._download_cache.get(download_job.id, None):
+                install_job.status = InstallStatus.DOWNLOADING
+
+                if install_job.local_path == install_job._install_tmpdir:  # first time
+                    assert download_job.download_path
+                    install_job.local_path = download_job.download_path
+                install_job.download_parts = download_job.download_parts
+                install_job.bytes = sum(x.bytes for x in download_job.download_parts)
+                install_job.total_bytes = download_job.total_bytes
+                self._signal_job_download_started(install_job)

-    def _download_progress_callback(self, download_job: DownloadJob) -> None:
+    def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
        with self._lock:
-            install_job = self._download_cache[download_job.source]
-            if install_job.cancelled:  # This catches the case in which the caller directly calls job.cancel()
-                self._cancel_download_parts(install_job)
-            else:
-                # update sizes
-                install_job.bytes = sum(x.bytes for x in install_job.download_parts)
-                self._signal_job_downloading(install_job)
+            if install_job := self._download_cache.get(download_job.id, None):
+                if install_job.cancelled:  # This catches the case in which the caller directly calls job.cancel()
+                    self._download_queue.cancel_job(download_job)
+                else:
+                    # update sizes
+                    install_job.bytes = sum(x.bytes for x in download_job.download_parts)
+                    install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts)
+                    self._signal_job_downloading(install_job)

-    def _download_complete_callback(self, download_job: DownloadJob) -> None:
-        self._logger.info(f"Model download complete: {download_job.source}")
+    def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
        with self._lock:
-            install_job = self._download_cache[download_job.source]
-
-            # are there any more active jobs left in this task?
-            if install_job.downloading and all(x.complete for x in install_job.download_parts):
-                self._signal_job_downloads_done(install_job)
-                self._put_in_queue(install_job)
-
-            # Let other threads know that the number of downloads has changed
-            self._download_cache.pop(download_job.source, None)
-            self._downloads_changed_event.set()
+            if install_job := self._download_cache.pop(download_job.id, None):
+                self._signal_job_downloads_done(install_job)
+                self._put_in_queue(install_job)  # this starts the installation and registration
+
+                # Let other threads know that the number of downloads has changed
+                self._downloads_changed_event.set()

-    def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
+    def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
        with self._lock:
-            install_job = self._download_cache.pop(download_job.source, None)
-            assert install_job is not None
-            assert excp is not None
-            install_job.set_error(excp)
-            self._logger.error(
-                f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
-            )
-            self._cancel_download_parts(install_job)
-
-            # Let other threads know that the number of downloads has changed
-            self._downloads_changed_event.set()
+            if install_job := self._download_cache.pop(download_job.id, None):
+                assert excp is not None
+                install_job.set_error(excp)
+                self._download_queue.cancel_job(download_job)
+
+                # Let other threads know that the number of downloads has changed
+                self._downloads_changed_event.set()

-    def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
+    def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
        with self._lock:
-            install_job = self._download_cache.pop(download_job.source, None)
-            if not install_job:
-                return
-            self._downloads_changed_event.set()
-            self._logger.warning(f"Model download canceled: {download_job.source}")
-            # if install job has already registered an error, then do not replace its status with cancelled
-            if not install_job.errored:
-                install_job.cancel()
-            self._cancel_download_parts(install_job)
-
-            # Let other threads know that the number of downloads has changed
-            self._downloads_changed_event.set()
+            if install_job := self._download_cache.pop(download_job.id, None):
+                self._downloads_changed_event.set()
+                # if install job has already registered an error, then do not replace its status with cancelled
+                if not install_job.errored:
+                    install_job.cancel()
+
+                # Let other threads know that the number of downloads has changed
+                self._downloads_changed_event.set()

-    def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
-        # on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
-        # do not lock here because it gets called within a locked context
-        for s in install_job.download_parts:
-            self._download_queue.cancel_job(s)
-
-        if all(x.in_terminal_state for x in install_job.download_parts):
-            # When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
-            self._put_in_queue(install_job)
-
    # ------------------------------------------------------------------------------------------------
    # Internal methods that put events on the event bus
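
The subfolder handling in the new `_multifile_download()` above is easiest to see with concrete values. A small worked example of the same `pathlib` arithmetic (the file names are illustrative, not taken from the diff):

```py
from pathlib import Path

remote_path = Path("sdxl-turbo/vae/config.json")  # path reported by the HF metadata
subfolder = Path("vae")

top = Path(remote_path.parts[0])            # sdxl-turbo
path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae
path_to_add = Path(f"{top}_{subfolder}")    # sdxl-turbo_vae

# The remote file lands under a synthesized "<repo>_<subfolder>" directory:
print(path_to_add / remote_path.relative_to(path_to_remove))  # sdxl-turbo_vae/config.json
```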
@@ -859,8 +874,18 @@ class ModelInstallService(ModelInstallServiceBase):
        if self._event_bus:
            self._event_bus.emit_model_install_started(job)

+    def _signal_job_download_started(self, job: ModelInstallJob) -> None:
+        if self._event_bus:
+            assert job._multifile_job is not None
+            assert job.bytes is not None
+            assert job.total_bytes is not None
+            self._event_bus.emit_model_install_download_started(job)
+
    def _signal_job_downloading(self, job: ModelInstallJob) -> None:
        if self._event_bus:
+            assert job._multifile_job is not None
+            assert job.bytes is not None
+            assert job.total_bytes is not None
            self._event_bus.emit_model_install_download_progress(job)

    def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
@@ -875,6 +900,8 @@ class ModelInstallService(ModelInstallServiceBase):
        self._logger.info(f"Model install complete: {job.source}")
        self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
        if self._event_bus:
+            assert job.local_path is not None
+            assert job.config_out is not None
            self._event_bus.emit_model_install_complete(job)

    def _signal_job_errored(self, job: ModelInstallJob) -> None:
@@ -890,7 +917,13 @@ class ModelInstallService(ModelInstallServiceBase):
            self._event_bus.emit_model_install_cancelled(job)

    @staticmethod
-    def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
+    def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
+        """
+        Return a metadata fetcher appropriate for provided url.
+
+        This used to be more useful, but the number of supported model
+        sources has been reduced to HuggingFace alone.
+        """
        if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
            return HuggingFaceMetadataFetch
        raise ValueError(f"Unsupported model source: '{url}'")
@@ -2,10 +2,11 @@
"""Base class for model loader."""

from abc import ABC, abstractmethod
-from typing import Optional
+from pathlib import Path
+from typing import Callable, Optional

from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
-from invokeai.backend.model_manager.load import LoadedModel
+from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

@@ -31,3 +32,26 @@ class ModelLoadServiceBase(ABC):
    @abstractmethod
    def convert_cache(self) -> ModelConvertCacheBase:
        """Return the checkpoint convert cache used by this loader."""
+
+    @abstractmethod
+    def load_model_from_path(
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+    ) -> LoadedModelWithoutConfig:
+        """
+        Load the model file or directory located at the indicated Path.
+
+        This will load an arbitrary model file into the RAM cache. If the optional loader
+        argument is provided, the loader will be invoked to load the model into
+        memory. Otherwise the method will call safetensors.torch.load_file() or
+        torch.load() as appropriate to the file suffix.
+
+        Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
+        LoadedModel, but without the config attribute.
+
+        Args:
+            model_path: A pathlib.Path to a checkpoint-style models file
+            loader: A Callable that expects a Path and returns a Dict[str, Tensor]
+
+        Returns:
+            A LoadedModel object.
+        """
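
A hedged usage sketch for the new `load_model_from_path()` API above: a custom `loader` callable replaces the default safetensors/torch loading. The import path is assumed from the package layout shown in the diff, and treating the returned `LoadedModelWithoutConfig` as a context manager is also an assumption:

```py
from pathlib import Path

from safetensors.torch import load_file

from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase


def load_raw_state_dict(model_load: ModelLoadServiceBase, checkpoint: Path):
    # Custom loader: return whatever object should be cached; here, a CPU state dict.
    loaded = model_load.load_model_from_path(checkpoint, loader=lambda p: load_file(p, device="cpu"))
    with loaded as state_dict:  # lock the model in the RAM cache while we use it
        return sorted(state_dict.keys())
```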
@@ -1,18 +1,26 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service."""

-from typing import Optional, Type
+from pathlib import Path
+from typing import Callable, Optional, Type
+
+from picklescan.scanner import scan_file_path
+from safetensors.torch import load_file as safetensors_load_file
+from torch import load as torch_load

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
    LoadedModel,
+    LoadedModelWithoutConfig,
    ModelLoaderRegistry,
    ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
+from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

from .model_load_base import ModelLoadServiceBase
@@ -75,3 +83,41 @@ class ModelLoadService(ModelLoadServiceBase):
            self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)

        return loaded_model
+
+    def load_model_from_path(
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+    ) -> LoadedModelWithoutConfig:
+        cache_key = str(model_path)
+        ram_cache = self.ram_cache
+        try:
+            return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
+        except IndexError:
+            pass
+
+        def torch_load_file(checkpoint: Path) -> AnyModel:
+            scan_result = scan_file_path(checkpoint)
+            if scan_result.infected_files != 0:
+                raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
+            result = torch_load(checkpoint, map_location="cpu")
+            return result
+
+        def diffusers_load_directory(directory: Path) -> AnyModel:
+            load_class = GenericDiffusersLoader(
+                app_config=self._app_config,
+                logger=self._logger,
+                ram_cache=self._ram_cache,
+                convert_cache=self.convert_cache,
+            ).get_hf_load_class(directory)
+            return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
+
+        loader = loader or (
+            diffusers_load_directory
+            if model_path.is_dir()
+            else torch_load_file
+            if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
+            else lambda path: safetensors_load_file(path, device="cpu")
+        )
+        assert loader is not None
+        raw_model = loader(model_path)
+        ram_cache.put(key=cache_key, model=raw_model)
+        return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
@@ -12,15 +12,13 @@ from pydantic import BaseModel, Field

from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
-from invokeai.backend.model_manager import (
+from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
-    ModelFormat,
-    ModelType,
-)
-from invokeai.backend.model_manager.config import (
    ControlAdapterDefaultSettings,
    MainModelDefaultSettings,
+    ModelFormat,
+    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
)
@@ -37,10 +37,14 @@ class SqliteSessionQueue(SessionQueueBase):
    def start(self, invoker: Invoker) -> None:
        self.__invoker = invoker
        self._set_in_progress_to_canceled()
-        prune_result = self.prune(DEFAULT_QUEUE_ID)
-        if prune_result.deleted > 0:
-            self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
+        if self.__invoker.services.configuration.clear_queue_on_startup:
+            clear_result = self.clear(DEFAULT_QUEUE_ID)
+            if clear_result.deleted > 0:
+                self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
+        else:
+            prune_result = self.prune(DEFAULT_QUEUE_ID)
+            if prune_result.deleted > 0:
+                self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")

    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
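
The new `clear_queue_on_startup` setting drives the branch above: when enabled, the whole session queue is cleared at startup instead of merely pruning finished items. A hedged sketch of the field; the field name comes from the diff, but constructing the config directly is purely for illustration (it would normally be set via the app's configuration file):

```py
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig(clear_queue_on_startup=True)
print(config.clear_queue_on_startup)  # True -> start() calls clear() instead of prune()
```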
@@ -2,18 +2,19 @@

import copy
import itertools
-from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
+from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints

import networkx as nx
from pydantic import (
    BaseModel,
+    GetCoreSchemaHandler,
    GetJsonSchemaHandler,
    ValidationError,
    field_validator,
)
from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue
-from pydantic_core import CoreSchema
+from pydantic_core import core_schema

# Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import *  # noqa: F401 F403
@@ -277,73 +278,58 @@ class CollectInvocation(BaseInvocation):
        return CollectInvocationOutput(collection=copy.copy(self.collection))


+class AnyInvocation(BaseInvocation):
+    @classmethod
+    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
+        def validate_invocation(v: Any) -> "AnyInvocation":
+            return BaseInvocation.get_typeadapter().validate_python(v)
+
+        return core_schema.no_info_plain_validator_function(validate_invocation)
+
+    @classmethod
+    def __get_pydantic_json_schema__(
+        cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+    ) -> JsonSchemaValue:
+        # Nodes are too powerful, we have to make our own OpenAPI schema manually
+        # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
+        oneOf: list[dict[str, str]] = []
+        names = [i.__name__ for i in BaseInvocation.get_invocations()]
+        for name in sorted(names):
+            oneOf.append({"$ref": f"#/components/schemas/{name}"})
+        return {"oneOf": oneOf}
+
+
+class AnyInvocationOutput(BaseInvocationOutput):
+    @classmethod
+    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
+        def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
+            return BaseInvocationOutput.get_typeadapter().validate_python(v)
+
+        return core_schema.no_info_plain_validator_function(validate_invocation_output)
+
+    @classmethod
+    def __get_pydantic_json_schema__(
+        cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+    ) -> JsonSchemaValue:
+        # Nodes are too powerful, we have to make our own OpenAPI schema manually
+        # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
+
+        oneOf: list[dict[str, str]] = []
+        names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
+        for name in sorted(names):
+            oneOf.append({"$ref": f"#/components/schemas/{name}"})
+        return {"oneOf": oneOf}
+
+
class Graph(BaseModel):
    id: str = Field(description="The id of this graph", default_factory=uuid_string)
    # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
-    nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
+    nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
    edges: list[Edge] = Field(
        description="The connections between nodes and their fields in this graph",
        default_factory=list,
    )

-    @field_validator("nodes", mode="plain")
-    @classmethod
-    def validate_nodes(cls, v: dict[str, Any]):
-        """Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
-
-        # Invocations register themselves as their python modules are executed. The union of all invocations is
-        # constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
-        #
-        # It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
-        # we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
-        # invocations will cause a graph to fail if they are used.
-        #
-        # We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
-        # pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
-        #
-        # This same pattern is used in `GraphExecutionState`.
-
-        nodes: dict[str, BaseInvocation] = {}
-        typeadapter = BaseInvocation.get_typeadapter()
-        for node_id, node in v.items():
-            nodes[node_id] = typeadapter.validate_python(node)
-        return nodes
-
-    @classmethod
-    def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
-        # We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
-        # fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
-        # the generated schema as options for the `nodes` field.
-        #
-        # The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
-        # with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
-        # expected.
-        #
-        # You might be tempted to do something like this:
-        #
-        # ```py
-        # cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
-        # delattr(cloned_model, "validate_nodes")
-        # cloned_model.model_rebuild(force=True)
-        # json_schema = handler(cloned_model.__pydantic_core_schema__)
-        # ```
-        #
-        # Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
-        # to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
-        #
-        # This same pattern is used in `GraphExecutionState`.
-
-        class Graph(BaseModel):
-            id: Optional[str] = Field(default=None, description="The id of this graph")
-            nodes: dict[
-                str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
-            ] = Field(description="The nodes in this graph")
-            edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
-
-        json_schema = handler(Graph.__pydantic_core_schema__)
-        json_schema = handler.resolve_ref_schema(json_schema)
-        return json_schema
-
    def add_node(self, node: BaseInvocation) -> None:
        """Adds a node to a graph

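
The `AnyInvocation` / `AnyInvocationOutput` classes added above replace the per-field "plain" validators: they hand validation off to a `TypeAdapter` built at runtime from whatever invocations happen to be registered. A minimal standalone sketch of that pydantic pattern; the `Base`/`Foo`/`AnyBase` models are illustrative only, not InvokeAI classes:

```py
from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler, TypeAdapter
from pydantic_core import core_schema


class Base(BaseModel):
    type: str


class Foo(Base):
    type: str = "foo"
    value: int = 0


class AnyBase(Base):
    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        def validate(v: Any) -> Base:
            # InvokeAI rebuilds this TypeAdapter from all registered invocations at runtime.
            return TypeAdapter(Foo).validate_python(v)

        return core_schema.no_info_plain_validator_function(validate)


class Container(BaseModel):
    nodes: dict[str, AnyBase]


container = Container.model_validate({"nodes": {"a": {"type": "foo", "value": 3}}})
print(type(container.nodes["a"]).__name__)  # Foo
```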
@@ -774,7 +760,7 @@ class GraphExecutionState(BaseModel):
    )

    # The results of executed nodes
-    results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
+    results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)

    # Errors raised when executing nodes
    errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@@ -791,52 +777,12 @@ class GraphExecutionState(BaseModel):
        default_factory=dict,
    )

-    @field_validator("results", mode="plain")
-    @classmethod
-    def validate_results(cls, v: dict[str, BaseInvocationOutput]):
-        """Validates the results in the GES by retrieving a union of all output types and validating each result."""
-
-        # See the comment in `Graph.validate_nodes` for an explanation of this logic.
-        results: dict[str, BaseInvocationOutput] = {}
-        typeadapter = BaseInvocationOutput.get_typeadapter()
-        for result_id, result in v.items():
-            results[result_id] = typeadapter.validate_python(result)
-        return results
-
    @field_validator("graph")
    def graph_is_valid(cls, v: Graph):
        """Validates that the graph is valid"""
        v.validate_self()
        return v

-    @classmethod
-    def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
-        # See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
-        class GraphExecutionState(BaseModel):
-            """Tracks the state of a graph execution"""
-
-            id: str = Field(description="The id of the execution state")
-            graph: Graph = Field(description="The graph being executed")
-            execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
-            executed: set[str] = Field(description="The set of node ids that have been executed")
-            executed_history: list[str] = Field(
-                description="The list of node ids that have been executed, in order of execution"
-            )
-            results: dict[
-                str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
-            ] = Field(description="The results of node executions")
-            errors: dict[str, str] = Field(description="Errors raised when executing nodes")
-            prepared_source_mapping: dict[str, str] = Field(
-                description="The map of prepared nodes to original graph nodes"
-            )
-            source_prepared_mapping: dict[str, set[str]] = Field(
-                description="The map of original graph nodes to prepared nodes"
-            )
-
-        json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
-        json_schema = handler.resolve_ref_schema(json_schema)
-        return json_schema
-
    def next(self) -> Optional[BaseInvocation]:
        """Gets the next node ready to execute."""

@@ -3,6 +3,7 @@ from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional, Union

from PIL.Image import Image
+from pydantic.networks import AnyHttpUrl
from torch import Tensor

from invokeai.app.invocations.constants import IMAGE_MODES
@@ -14,8 +15,15 @@ from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
-from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.model_manager.config import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData

@@ -320,8 +328,10 @@ class ConditioningInterface(InvocationContextInterface):


class ModelsInterface(InvocationContextInterface):
+    """Common API for loading, downloading and managing models."""
+
    def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
-        """Checks if a model exists.
+        """Check if a model exists.

        Args:
            identifier: The key or ModelField representing the model.
@@ -331,13 +341,13 @@ class ModelsInterface(InvocationContextInterface):
        """
        if isinstance(identifier, str):
            return self._services.model_manager.store.exists(identifier)
-        return self._services.model_manager.store.exists(identifier.key)
+        else:
+            return self._services.model_manager.store.exists(identifier.key)

    def load(
        self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
    ) -> LoadedModel:
-        """Loads a model.
+        """Load a model.

        Args:
            identifier: The key or ModelField representing the model.
@@ -361,7 +371,7 @@ class ModelsInterface(InvocationContextInterface):
    def load_by_attrs(
        self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
    ) -> LoadedModel:
-        """Loads a model by its attributes.
+        """Load a model by its attributes.

        Args:
            name: Name of the model.
@@ -384,7 +394,7 @@ class ModelsInterface(InvocationContextInterface):
        return self._services.model_manager.load.load_model(configs[0], submodel_type)

    def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
-        """Gets a model's config.
+        """Get a model's config.

        Args:
            identifier: The key or ModelField representing the model.
@@ -394,11 +404,11 @@ class ModelsInterface(InvocationContextInterface):
        """
        if isinstance(identifier, str):
            return self._services.model_manager.store.get_model(identifier)
-        return self._services.model_manager.store.get_model(identifier.key)
+        else:
+            return self._services.model_manager.store.get_model(identifier.key)

    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
-        """Searches for models by path.
+        """Search for models by path.

        Args:
            path: The path to search for.
@@ -415,7 +425,7 @@ class ModelsInterface(InvocationContextInterface):
        type: Optional[ModelType] = None,
        format: Optional[ModelFormat] = None,
    ) -> list[AnyModelConfig]:
-        """Searches for models by attributes.
+        """Search for models by attributes.

        Args:
            name: The name to search for (exact match).
@@ -434,6 +444,72 @@ class ModelsInterface(InvocationContextInterface):
            model_format=format,
        )

+    def download_and_cache_model(
+        self,
+        source: str | AnyHttpUrl,
+    ) -> Path:
+        """
+        Download the model file located at source to the models cache and return its Path.
+
+        This can be used to single-file install models and other resources of arbitrary types
+        which should not get registered with the database. If the model is already
+        installed, the cached path will be returned. Otherwise it will be downloaded.
+
+        Args:
+            source: A URL that points to the model, or a huggingface repo_id.
+
+        Returns:
+            Path to the downloaded model
+        """
+        return self._services.model_manager.install.download_and_cache_model(source=source)
+
+    def load_local_model(
+        self,
+        model_path: Path,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
+    ) -> LoadedModelWithoutConfig:
+        """
+        Load the model file located at the indicated path
+
+        If a loader callable is provided, it will be invoked to load the model. Otherwise,
+        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute
+
+        Args:
+            path: A model Path
+            loader: A Callable that expects a Path and returns a dict[str|int, Any]
+
+        Returns:
+            A LoadedModelWithoutConfig object.
+        """
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
+    def load_remote_model(
+        self,
+        source: str | AnyHttpUrl,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
+    ) -> LoadedModelWithoutConfig:
+        """
+        Download, cache, and load the model file located at the indicated URL or repo_id.
+
+        If the model is already downloaded, it will be loaded from the cache.
+
+        If the a loader callable is provided, it will be invoked to load the model. Otherwise,
+        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute
+
+        Args:
+            source: A URL or huggingface repoid.
+            loader: A Callable that expects a Path and returns a dict[str|int, Any]
+
+        Returns:
+            A LoadedModelWithoutConfig object.
+        """
+        model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
+
class ConfigInterface(InvocationContextInterface):
    def get(self) -> InvokeAIAppConfig:
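
A hedged sketch of how a custom node might use the new `ModelsInterface` helpers above; the `InvocationContext` is supplied by the app at invoke time, and the URL is a placeholder:

```py
from pathlib import Path

from invokeai.app.services.shared.invocation_context import InvocationContext


def fetch_and_load(context: InvocationContext):
    # Download to the models download cache (or reuse an earlier download) without
    # registering the file in the model database...
    model_path: Path = context.models.download_and_cache_model("https://example.com/model.safetensors")
    # ...then load it into the RAM cache; the returned object has no `config` attribute.
    return context.models.load_local_model(model_path)
```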
@@ -13,6 +13,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
+from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


@@ -43,6 +44,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
    migrator.register_migration(build_migration_8(app_config=config))
    migrator.register_migration(build_migration_9())
    migrator.register_migration(build_migration_10())
+    migrator.register_migration(build_migration_11(app_config=config, logger=logger))
    migrator.run_migrations()

    return db
@@ -0,0 +1,75 @@
import shutil
import sqlite3
from logging import Logger

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration

LEGACY_CORE_MODELS = [
    # OpenPose
    "any/annotators/dwpose/yolox_l.onnx",
    "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
    # DepthAnything
    "any/annotators/depth_anything/depth_anything_vitl14.pth",
    "any/annotators/depth_anything/depth_anything_vitb14.pth",
    "any/annotators/depth_anything/depth_anything_vits14.pth",
    # Lama inpaint
    "core/misc/lama/lama.pt",
    # RealESRGAN upscale
    "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
    "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
    "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
    "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
]


class Migration11Callback:
    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
        self._app_config = app_config
        self._logger = logger

    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._remove_convert_cache()
        self._remove_downloaded_models()
        self._remove_unused_core_models()

    def _remove_convert_cache(self) -> None:
        """Remove models/.cache; converted models will now be cached in models/.convert_cache."""
        self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
        legacy_convert_path = self._app_config.root_path / "models" / ".cache"
        shutil.rmtree(legacy_convert_path, ignore_errors=True)

    def _remove_downloaded_models(self) -> None:
        """Remove models from their old locations; they will re-download when needed."""
        self._logger.info(
            "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
        )
        for model_path in LEGACY_CORE_MODELS:
            legacy_dest_path = self._app_config.models_path / model_path
            legacy_dest_path.unlink(missing_ok=True)

    def _remove_unused_core_models(self) -> None:
        """Remove unused core models and their directories."""
        self._logger.info("Removing defunct core models.")
        for dir in ["face_restoration", "misc", "upscaling"]:
            path_to_remove = self._app_config.models_path / "core" / dir
            shutil.rmtree(path_to_remove, ignore_errors=True)
        shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)


def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
    """
    Build the migration from database version 10 to 11.

    This migration does the following:
    - Moves "core" models previously downloaded with download_with_progress_bar() into new
      "models/.download_cache" directory.
    - Renames "models/.cache" to "models/.convert_cache".
    """
    migration_11 = Migration(
        from_version=10,
        to_version=11,
        callback=Migration11Callback(app_config=app_config, logger=logger),
    )

    return migration_11

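A hedged sketch of exercising the new migration callback directly, for instance in a test. The callback ignores its cursor and only performs filesystem cleanup, so an in-memory database is sufficient; note that running it against a real install actually deletes the legacy files. `InvokeAIAppConfig()` with defaults is an assumption about the local layout.

```python
import sqlite3
from logging import getLogger

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import Migration11Callback

config = InvokeAIAppConfig()  # assumes a default local install layout
callback = Migration11Callback(app_config=config, logger=getLogger(__name__))

conn = sqlite3.connect(":memory:")
callback(conn.cursor())  # removes models/.cache and the legacy just-in-time core models
```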
@@ -289,7 +289,7 @@ def prepare_control_image(
     width: int,
     height: int,
     num_channels: int = 3,
-    device: str = "cuda",
+    device: str | torch.device = "cuda",
     dtype: torch.dtype = torch.float16,
     control_mode: CONTROLNET_MODE_VALUES = "balanced",
     resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
@@ -304,7 +304,7 @@ def prepare_control_image(
         num_channels (int, optional): The target number of image channels. This is achieved by converting the input
             image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
             RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
-        device (str, optional): The target device for the output image. Defaults to "cuda".
+        device (str | torch.device, optional): The target device for the output image. Defaults to "cuda".
         dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
         do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
             Defaults to True.

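A small illustrative sketch of the widened `device` parameter. The import path of `prepare_control_image` is not shown in the hunk, so the function is passed in here; the positional image argument is also an assumption.

```python
import torch
from PIL import Image


def make_control_tensor(prepare_control_image, image: Image.Image) -> torch.Tensor:
    # A torch.device instance is now accepted directly, not only the string form.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return prepare_control_image(
        image,
        width=image.width,
        height=image.height,
        device=device,
        dtype=torch.float16,
    )
```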
116  invokeai/app/util/custom_openapi.py  Normal file
@@ -0,0 +1,116 @@
from typing import Any, Callable, Optional

from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage


def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
    """Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
    for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
    component_schema."""

    defs = component_schema.pop("$defs", {})
    for schema_key, json_schema in defs.items():
        if schema_key in openapi_schema["components"]["schemas"]:
            continue
        openapi_schema["components"]["schemas"][schema_key] = json_schema


def get_openapi_func(
    app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
) -> Callable[[], dict[str, Any]]:
    """Gets the OpenAPI schema generator function.

    Args:
        app (FastAPI): The FastAPI app to generate the schema for.
        post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
            generated schema before returning it. Defaults to None.

    Returns:
        Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
            cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
            matches FastAPI's default schema generation caching.
    """

    def openapi() -> dict[str, Any]:
        if app.openapi_schema:
            return app.openapi_schema

        openapi_schema = get_openapi(
            title=app.title,
            description="An API for invoking AI image operations",
            version="1.0.0",
            routes=app.routes,
            separate_input_output_schemas=False,  # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
        )

        # We'll create a map of invocation type to output schema to make some types simpler on the client.
        invocation_output_map_properties: dict[str, Any] = {}
        invocation_output_map_required: list[str] = []

        # We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
        for output in BaseInvocationOutput.get_outputs():
            json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
            move_defs_to_top_level(openapi_schema, json_schema)
            openapi_schema["components"]["schemas"][output.__name__] = json_schema

        # Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
        # property, so we'll just do it all manually.
        for invocation in BaseInvocation.get_invocations():
            json_schema = invocation.model_json_schema(
                mode="serialization", ref_template="#/components/schemas/{model}"
            )
            move_defs_to_top_level(openapi_schema, json_schema)
            output_title = invocation.get_output_annotation().__name__
            outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
            json_schema["output"] = outputs_ref
            openapi_schema["components"]["schemas"][invocation.__name__] = json_schema

            # Add this invocation and its output to the output map
            invocation_type = invocation.get_type()
            invocation_output_map_properties[invocation_type] = json_schema["output"]
            invocation_output_map_required.append(invocation_type)

        # Add the output map to the schema
        openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
            "type": "object",
            "properties": invocation_output_map_properties,
            "required": invocation_output_map_required,
        }

        # Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
        # We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
        # a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
        # `models_json_schema` seems to work fine.
        additional_models = [
            *EventBase.get_events(),
            UIConfigBase,
            InputFieldJSONSchemaExtra,
            OutputFieldJSONSchemaExtra,
            ModelIdentifierField,
            ProgressImage,
        ]

        additional_schemas = models_json_schema(
            [(m, "serialization") for m in additional_models],
            ref_template="#/components/schemas/{model}",
        )
        # additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
        move_defs_to_top_level(openapi_schema, additional_schemas[1])

        if post_transform is not None:
            openapi_schema = post_transform(openapi_schema)

        openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))

        app.openapi_schema = openapi_schema
        return app.openapi_schema

    return openapi

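A brief sketch of how the generator is expected to be wired into an app. The `post_transform` hook and the metadata key below are illustrative assumptions; assigning the returned callable to `app.openapi` replaces FastAPI's default schema generation.

```python
from typing import Any

from fastapi import FastAPI

from invokeai.app.util.custom_openapi import get_openapi_func

app = FastAPI()


def add_server_metadata(schema: dict[str, Any]) -> dict[str, Any]:
    # Example post_transform: tweak the generated schema before it is cached.
    schema["info"]["x-schema-source"] = "invokeai-custom-openapi"
    return schema


# The first call builds and caches the schema; later calls return the cached copy.
app.openapi = get_openapi_func(app, post_transform=add_server_metadata)
```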
@@ -1,51 +0,0 @@
from pathlib import Path
from urllib import request

from tqdm import tqdm

from invokeai.backend.util.logging import InvokeAILogger


class ProgressBar:
    """Simple progress bar for urllib.request.urlretrieve using tqdm."""

    def __init__(self, model_name: str = "file"):
        self.pbar = None
        self.name = model_name

    def __call__(self, block_num: int, block_size: int, total_size: int):
        if not self.pbar:
            self.pbar = tqdm(
                desc=self.name,
                initial=0,
                unit="iB",
                unit_scale=True,
                unit_divisor=1000,
                total=total_size,
            )
        self.pbar.update(block_size)


def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool:
    """Download a file from a URL to a destination path, with a progress bar.
    If the file already exists, it will not be downloaded again.

    Exceptions are not caught.

    Args:
        name (str): Name of the file being downloaded.
        url (str): URL to download the file from.
        dest_path (Path): Destination path to save the file to.

    Returns:
        bool: True if the file was downloaded, False if it already existed.
    """
    if dest_path.exists():
        return False  # already downloaded

    InvokeAILogger.get_logger().info(f"Downloading {name}...")

    dest_path.parent.mkdir(parents=True, exist_ok=True)
    request.urlretrieve(url, dest_path, ProgressBar(name))

    return True

@@ -1,5 +1,5 @@
-import pathlib
-from typing import Literal, Union
+from pathlib import Path
+from typing import Literal

 import cv2
 import numpy as np
@@ -10,28 +10,17 @@ from PIL import Image
 from torchvision.transforms import Compose

 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
-from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 config = get_config()
 logger = InvokeAILogger.get_logger(config=config)

 DEPTH_ANYTHING_MODELS = {
-    "large": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
-    },
-    "base": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
-    },
-    "small": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vits14.pth",
-    },
+    "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
+    "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
+    "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
 }

@@ -53,36 +42,27 @@ transform = Compose(

 class DepthAnythingDetector:
-    def __init__(self) -> None:
-        self.model = None
-        self.model_size: Union[Literal["large", "base", "small"], None] = None
-        self.device = TorchDevice.choose_torch_device()
-
-    def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
-        DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
-        download_with_progress_bar(
-            pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
-            DEPTH_ANYTHING_MODELS[model_size]["url"],
-            DEPTH_ANYTHING_MODEL_PATH,
-        )
-
-        if not self.model or model_size != self.model_size:
-            del self.model
-            self.model_size = model_size
-
-            match self.model_size:
-                case "small":
-                    self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
-                case "base":
-                    self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
-                case "large":
-                    self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
-
-            self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
-            self.model.eval()
-
-            self.model.to(self.device)
-            return self.model
+    def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
+        self.model = model
+        self.device = device
+
+    @staticmethod
+    def load_model(
+        model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
+    ) -> DPT_DINOv2:
+        match model_size:
+            case "small":
+                model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
+            case "base":
+                model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
+            case "large":
+                model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
+
+        model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
+        model.eval()
+
+        model.to(device)
+        return model

     def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
         if not self.model:

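A hedged usage sketch of the refactored detector: the caller now supplies the checkpoint path (for example, one obtained via the model manager's download cache) instead of the detector downloading it internally. The checkpoint and image paths below are placeholders.

```python
from pathlib import Path

import torch
from PIL import Image

from invokeai.backend.image_util.depth_anything import DepthAnythingDetector

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DepthAnythingDetector.load_model(
    Path("/path/to/depth_anything_vits14.pth"), device=device, model_size="small"
)
detector = DepthAnythingDetector(model=model, device=device)
depth_map = detector(Image.open("input.png"), resolution=512)
```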
@@ -1,30 +1,53 @@
+from pathlib import Path
+from typing import Dict
+
 import numpy as np
 import torch
 from controlnet_aux.util import resize_image
 from PIL import Image

-from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
+from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
 from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody

+DWPOSE_MODELS = {
+    "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
+    "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
+}
+

-def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512):
+def draw_pose(
+    pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
+    H: int,
+    W: int,
+    draw_face: bool = True,
+    draw_body: bool = True,
+    draw_hands: bool = True,
+    resolution: int = 512,
+) -> Image.Image:
     bodies = pose["bodies"]
     faces = pose["faces"]
     hands = pose["hands"]

+    assert isinstance(bodies, dict)
     candidate = bodies["candidate"]

+    assert isinstance(bodies, dict)
     subset = bodies["subset"]

     canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)

     if draw_body:
         canvas = draw_bodypose(canvas, candidate, subset)

     if draw_hands:
+        assert isinstance(hands, np.ndarray)
         canvas = draw_handpose(canvas, hands)

     if draw_face:
-        canvas = draw_facepose(canvas, faces)
+        assert isinstance(hands, np.ndarray)
+        canvas = draw_facepose(canvas, faces)  # type: ignore

-    dwpose_image = resize_image(
+    dwpose_image: Image.Image = resize_image(
         canvas,
         resolution,
     )
@@ -39,11 +62,16 @@ class DWOpenposeDetector:
     Credits: https://github.com/IDEA-Research/DWPose
     """

-    def __init__(self) -> None:
-        self.pose_estimation = Wholebody()
+    def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
+        self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)

     def __call__(
-        self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
+        self,
+        image: Image.Image,
+        draw_face: bool = False,
+        draw_body: bool = True,
+        draw_hands: bool = False,
+        resolution: int = 512,
     ) -> Image.Image:
         np_image = np.array(image)
         H, W, C = np_image.shape
@@ -79,3 +107,6 @@ class DWOpenposeDetector:
         return draw_pose(
             pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
         )
+
+
+__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]

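A hedged sketch of the new constructor: the two ONNX files are resolved by the caller (for example, via the model manager's download cache using the URLs in `DWPOSE_MODELS`) and passed in explicitly. The file paths below are placeholders.

```python
from pathlib import Path

from PIL import Image

from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector

onnx_det = Path("/path/to/yolox_l.onnx")  # e.g. downloaded from DWPOSE_MODELS["yolox_l.onnx"]
onnx_pose = Path("/path/to/dw-ll_ucoco_384.onnx")

detector = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
pose_image = detector(Image.open("person.png"), draw_body=True, resolution=512)
```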
@@ -5,11 +5,13 @@ import math
 import cv2
 import matplotlib
 import numpy as np
+import numpy.typing as npt

 eps = 0.01
+NDArrayInt = npt.NDArray[np.uint8]


-def draw_bodypose(canvas, candidate, subset):
+def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt:
     H, W, C = canvas.shape
     candidate = np.array(candidate)
     subset = np.array(subset)
@@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset):
     return canvas


-def draw_handpose(canvas, all_hand_peaks):
+def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
     H, W, C = canvas.shape

     edges = [
@@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks):
     return canvas


-def draw_facepose(canvas, all_lmks):
+def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt:
     H, W, C = canvas.shape
     for lmks in all_lmks:
         lmks = np.array(lmks)

@@ -2,47 +2,26 @@
 # Modified pathing to suit Invoke

+from pathlib import Path
+
 import numpy as np
 import onnxruntime as ort

 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.util.devices import TorchDevice

 from .onnxdet import inference_detector
 from .onnxpose import inference_pose

-DWPOSE_MODELS = {
-    "yolox_l.onnx": {
-        "local": "any/annotators/dwpose/yolox_l.onnx",
-        "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
-    },
-    "dw-ll_ucoco_384.onnx": {
-        "local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
-        "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
-    },
-}
-
 config = get_config()


 class Wholebody:
-    def __init__(self):
+    def __init__(self, onnx_det: Path, onnx_pose: Path):
         device = TorchDevice.choose_torch_device()

         providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]

-        DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
-        download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
-
-        POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
-        download_with_progress_bar(
-            "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
-        )
-
-        onnx_det = DET_MODEL_PATH
-        onnx_pose = POSE_MODEL_PATH
-
         self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
         self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)

@@ -1,4 +1,4 @@
-import gc
+from pathlib import Path
 from typing import Any

 import numpy as np
@@ -6,9 +6,7 @@ import torch
 from PIL import Image

 import invokeai.backend.util.logging as logger
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.model_manager.config import AnyModel


 def norm_img(np_img):
@@ -19,28 +17,11 @@ def norm_img(np_img):
     return np_img


-def load_jit_model(url_or_path, device):
-    model_path = url_or_path
-    logger.info(f"Loading model from: {model_path}")
-    model = torch.jit.load(model_path, map_location="cpu").to(device)
-    model.eval()
-    return model
-
-
 class LaMA:
+    def __init__(self, model: AnyModel):
+        self._model = model
+
     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-        device = TorchDevice.choose_torch_device()
-        model_location = get_config().models_path / "core/misc/lama/lama.pt"
-
-        if not model_location.exists():
-            download_with_progress_bar(
-                name="LaMa Inpainting Model",
-                url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-                dest_path=model_location,
-            )
-
-        model = load_jit_model(model_location, device)
-
         image = np.asarray(input_image.convert("RGB"))
         image = norm_img(image)

@@ -48,20 +29,25 @@ class LaMA:
         mask = np.asarray(mask)
         mask = np.invert(mask)
         mask = norm_img(mask)

         mask = (mask > 0) * 1

+        device = next(self._model.buffers()).device
         image = torch.from_numpy(image).unsqueeze(0).to(device)
         mask = torch.from_numpy(mask).unsqueeze(0).to(device)

         with torch.inference_mode():
-            infilled_image = model(image, mask)
+            infilled_image = self._model(image, mask)

         infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
         infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
         infilled_image = Image.fromarray(infilled_image)

-        del model
-        gc.collect()
-        torch.cuda.empty_cache()
-
         return infilled_image
+
+    @staticmethod
+    def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module:
+        model_path = url_or_path
+        logger.info(f"Loading model from: {model_path}")
+        model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device)  # type: ignore
+        model.eval()
+        return model

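A hedged sketch of the inverted control flow: the TorchScript checkpoint is loaded by the caller (normally through the model manager) and handed to `LaMA`, which now infers its device from the model's own buffers. The module import path and the checkpoint path are assumptions.

```python
from invokeai.backend.image_util.lama import LaMA  # import path assumed

jit_model = LaMA.load_jit_model("/path/to/lama.pt", device="cpu")  # placeholder checkpoint path
lama = LaMA(model=jit_model)
# lama(input_image) runs inpainting; the model's buffers determine the device used for inference.
```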
@@ -1,6 +1,5 @@
 import math
 from enum import Enum
-from pathlib import Path
 from typing import Any, Optional

 import cv2
@@ -11,6 +10,7 @@ from cv2.typing import MatLike
 from tqdm import tqdm

 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
+from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.util.devices import TorchDevice

 """
@@ -52,7 +52,7 @@ class RealESRGAN:
     def __init__(
         self,
         scale: int,
-        model_path: Path,
+        loadnet: AnyModel,
         model: RRDBNet,
         tile: int = 0,
         tile_pad: int = 10,
@@ -67,8 +67,6 @@ class RealESRGAN:
         self.half = half
         self.device = TorchDevice.choose_torch_device()

-        loadnet = torch.load(model_path, map_location=torch.device("cpu"))
-
         # prefer to use params_ema
         if "params_ema" in loadnet:
             keyname = "params_ema"

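A hedged sketch of the new constructor: the caller loads the state dict ("loadnet") itself, typically through the model manager cache, instead of passing a file path. The import path, checkpoint path, and the RRDBNet hyperparameters below are illustrative assumptions.

```python
import torch

from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN  # import path assumed

# The caller is responsible for loading the weights; RealESRGAN no longer calls torch.load itself.
loadnet = torch.load("/path/to/RealESRGAN_x4plus.pth", map_location="cpu")  # placeholder path
rrdbnet = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
upscaler = RealESRGAN(scale=4, loadnet=loadnet, model=rrdbnet)
```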
@@ -125,13 +125,16 @@ class IPAdapter(RawModel):
             self.device, dtype=self.dtype
         )

-    def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
-        self.device = device
+    def to(
+        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
+    ):
+        if device is not None:
+            self.device = device
         if dtype is not None:
             self.dtype = dtype

-        self._image_proj_model.to(device=self.device, dtype=self.dtype)
-        self.attn_weights.to(device=self.device, dtype=self.dtype)
+        self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
+        self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)

     def calc_size(self):
         # workaround for circular import

@@ -61,9 +61,10 @@ class LoRALayerBase:
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype)
+            self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)


     # TODO: find and debug lora/locon with bias
@@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
-        super().to(device=device, dtype=dtype)
+        super().to(device=device, dtype=dtype, non_blocking=non_blocking)

-        self.up = self.up.to(device=device, dtype=dtype)
-        self.down = self.down.to(device=device, dtype=dtype)
+        self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)

         if self.mid is not None:
-            self.mid = self.mid.to(device=device, dtype=dtype)
+            self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)


 class LoHALayer(LoRALayerBase):
@@ -169,18 +171,19 @@ class LoHALayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)

-        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype)
+            self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)

-        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)


 class LoKRLayer(LoRALayerBase):
@@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)

@@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
         else:
             assert self.w1_a is not None
             assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)

         if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype)
+            self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
         else:
             assert self.w2_a is not None
             assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)

         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)


 class FullLayer(LoRALayerBase):
@@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)

-        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)


 class IA3Layer(LoRALayerBase):
@@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ):
         super().to(device=device, dtype=dtype)

-        self.weight = self.weight.to(device=device, dtype=dtype)
-        self.on_input = self.on_input.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)


 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         # TODO: try revert if exception?
         for _key, layer in self.layers.items():
-            layer.to(device=device, dtype=dtype)
+            layer.to(device=device, dtype=dtype, non_blocking=non_blocking)

     def calc_size(self) -> int:
         model_size = 0
@@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()

-            layer.to(device=device, dtype=dtype)
+            layer.to(device=device, dtype=dtype, non_blocking=True)
             model.layers[layer_key] = layer

         return model

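A short hedged sketch of the effect of the new `non_blocking` parameter: callers that move a whole LoRA between devices can now request asynchronous copies, which is only a real win when the source tensors live in pinned host memory. The helper name and the already-loaded `lora_model` are assumptions.

```python
import torch


def move_lora(lora_model, device: torch.device, dtype: torch.dtype) -> None:
    # lora_model is a LoRAModelRaw instance loaded elsewhere; every layer's tensors are
    # forwarded the non_blocking flag, mirroring how freshly parsed layers are moved above.
    lora_model.to(device=device, dtype=dtype, non_blocking=True)
```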
26  invokeai/backend/model_hash/hash_validator.py  Normal file
@@ -0,0 +1,26 @@
import json
from base64 import b64decode


def validate_hash(hash: str):
    if ":" not in hash:
        return
    for enc_hash in hashes:
        alg, hash_ = hash.split(":")
        if alg == "blake3":
            alg = "blake3_single"
        map = json.loads(b64decode(enc_hash))
        if alg in map:
            if hash_ == map[alg]:
                raise Exception(
                    "This model can not be loaded. If you're looking for help, consider visiting https://www.redirectionprogram.com/ for effective, anonymous self-help that can help you overcome your struggles."
                )


hashes: list[str] = [
"eyJibGFrZTNfbXVsdGkiOiI3Yjc5ODZmM2QyNTk3MDZiMjVhZDRhM2NmNGM2MTcyNGNhZmQ0Yjc4NjI4MjIwNjMyZGU4NjVlM2UxNDEyMTVlIiwiYmxha2UzX3NpbmdsZSI6IjdiNzk4NmYzZDI1OTcwNmIyNWFkNGEzY2Y0YzYxNzI0Y2FmZDRiNzg2MjgyMjA2MzJkZTg2NWUzZTE0MTIxNWUiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNzdlZmU5MzRhZGQ3YmU5Njc3NmJkODM3NWJhZDQxN2QiLCJzaGExIjoiYmM2YzYxYzgwNDgyMTE2ZTY2ZGQyNTYwNjRkYTgxYjFlY2U4NzMzOCIsInNoYTIyNCI6IjgzNzNlZGM4ZTg4Y2UxMTljODdlOTM2OTY4ZWViMWNmMzdjZGY4NTBmZjhjOTZkYjNmMDc4YmE0Iiwic2hhMjU2IjoiNzNjYWMxZWRlZmUyZjdlODFkNjRiMTI2YjIxMmY2Yzk2ZTAwNjgyNGJjZmJkZDI3Y2E5NmUyNTk5ZTQwNzUwZiIsInNoYTM4NCI6IjlmNmUwNzlmOTNiNDlkMTg1YzEyNzY0OGQwNzE3YTA0N2E3MzYyNDI4YzY4MzBhNDViNzExODAwZDE4NjIwZDZjMjcwZGE3ZmY0Y2FjOTRmNGVmZDdiZWQ5OTlkOWU0ZCIsInNoYTUxMiI6IjAwNzE5MGUyYjk5ZjVlN2Q1OGZiYWI2YTk1YmY0NjJiODhkOTg1N2NlNjY4MTMyMGJmM2M0Y2ZiZmY0MjkxZmEzNTMyMTk3YzdkODc2YWQ3NjZhOTQyOTQ2Zjc1OWY2YTViNDBlM2I2MzM3YzIwNWI0M2JkOWMyN2JiMTljNzk0IiwiYmxha2UyYiI6IjlhN2VhNTQzY2ZhMmMzMWYyZDIyNjg2MjUwNzUyNDE0Mjc1OWJiZTA0MWZlMWJkMzQzNDM1MWQwNWZlYjI2OGY2MjU0OTFlMzlmMzdkYWQ4MGM2Y2UzYTE4ZjAxNGEzZjJiMmQ2OGU2OTc0MjRmNTU2M2Y5ZjlhYzc1MzJiMjEwIiwiYmxha2UycyI6ImYxZmMwMjA0YjdjNzIwNGJlNWI1YzY3NDEyYjQ2MjY5NWE3YjFlYWQ2M2E5ZGVkMjEzYjZmYTU0NGZjNjJlYzUiLCJzaGEzXzIyNCI6IjljZDQ3YTBhMzA3NmNmYzI0NjJhNTAzMjVmMjg4ZjFiYzJjMmY2NmU2ODIxODc5NjJhNzU0NjFmIiwic2hhM18yNTYiOiI4NTFlNGI1ZDI1MWZlZTFiYzk0ODU1OWNjMDNiNjhlNTllYWU5YWI1ZTUyYjA0OTgxYTRhOTU4YWQyMDdkYjYwIiwic2hhM18zODQiOiJiZDA2ZTRhZGFlMWQ0MTJmZjFjOTcxMDJkZDFlN2JmY2UzMDViYTgxMTgyNzM3NWY5NTI4OWJkOGIyYTUxNjdiMmUyNzZjODNjNTU3ODFhMTEyMDRhNzc5MTUwMzM5ZTEiLCJzaGEzXzUxMiI6ImQ1ZGQ2OGZmZmY5NGRhZjJhMDkzZTliNmM1MTBlZmZkNThmZTA0ODMyZGQzMzEyOTZmN2NkZmYzNmRhZmQ3NGMxY2VmNjUxNTBkZjk5OGM1ODgyY2MzMzk2MTk1ZTViYjc5OTY1OGFkMTQ3MzFiMjJmZWZiMWQzNmY2MWJjYzJjIiwic2hha2VfMTI4IjoiOWJlNTgwNWMwNjg1MmZmNDUzNGQ4ZDZmODYyMmFkOTJkMGUwMWE2Y2JmYjIwN2QxOTRmM2JkYThiOGNmNWU4ZiIsInNoYWtlXzI1NiI6IjRhYjgwYjY2MzcxYzdhNjBhYWM4NDVkMTZlNWMzZDNhMmM4M2FjM2FjZDNiNTBiNzdjYWYyYTNmMWMyY2ZjZjc5OGNjYjkxN2FjZjQzNzBmZDdjN2ZmODQ5M2Q3NGY1MWM4NGU3M2ViZGQ4MTRmM2MwMzk3YzI4ODlmNTI0Mzg3In0K",
"eyJibGFrZTNfbXVsdGkiOiI4ODlmYzIwMDA4NWY1NWY4YTA4MjhiODg3MDM0OTRhMGFmNWZkZGI5N2E2YmYwMDRjM2VkYTdiYzBkNDU0MjQzIiwiYmxha2UzX3NpbmdsZSI6Ijg4OWZjMjAwMDg1ZjU1ZjhhMDgyOGI4ODcwMzQ5NGEwYWY1ZmRkYjk3YTZiZjAwNGMzZWRhN2JjMGQ0NTQyNDMiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNTIzNTRhMzkzYTVmOGNjNmMyMzQ0OThiYjcxMDljYzEiLCJzaGExIjoiMTJmYmRhOGE3ZGUwOGMwNDc2NTA5OWY2NGNmMGIzYjcxMjc1MGM1NyIsInNoYTIyNCI6IjEyZWU3N2U0Y2NhODViMDk4YjdjNWJlMWFjNGMwNzljNGM3MmJmODA2YjdlZjU1NGI0NzgxZDkxIiwic2hhMjU2IjoiMjU1NTMwZDAyYTY4MjY4OWE5ZTZjMjRhOWZhMDM2OGNhODMxZTI1OTAyYjM2NzQyNzkwZTk3NzU1ZjEzMmNmNSIsInNoYTM4NCI6IjhkMGEyMTRlNDk0NGE2NGY3ZmZjNTg3MGY0ZWUyZTA0OGIzYjRjMmQ0MGRmMWFmYTVlOGE1ZWNkN2IwOTY3M2ZjNWI5YzM5Yzg4Yjc2YmIwY2I4ZjQ1ZjAxY2MwNjZkNCIsInNoYTUxMiI6Ijg3NTM3OWNiYzdlOGYyNzU4YjVjMDY5ZTU2ZWRjODY1ODE4MGFkNDEzNGMwMzY1NzM4ZjM1YjQwYzI2M2JkMTMwMzcwZTE0MzZkNDNmOGFhMTgyMTg5MzgzMTg1ODNhOWJhYTUyYTBjMTk1Mjg5OTQzYzZiYTY2NTg1Yjg5M2ZiIiwiYmxha2UyYiI6IjBhY2MwNWEwOGE5YjhhODNmZTVjYTk4ZmExMTg3NTYwNjk0MjY0YWUxNTI4NDliYzFkNzQzNTYzMzMyMTlhYTg3N2ZiNjc4MmRjZDZiOGIyYjM1MTkyNDQzNDE2ODJiMTQ3YmY2YTY3MDU2ZWIwOTQ4MzE1M2E4Y2ZiNTNmMTI0IiwiYmxha2UycyI6ImY5ZTRhZGRlNGEzZDRhOTZhOWUyNjVjMGVmMjdmZDNiNjA0NzI1NDllMTEyMWQzOGQwMTkxNTY5ZDY5YzdhYzAiLCJzaGEzXzIyNCI6ImM0NjQ3MGRjMjkyNGI0YjZkMTA2NDY5MDRiNWM2OGVjNTU2YmQ4MTA5NmVkMTA4YjZiMzQyZmU1Iiwic2hhM18yNTYiOiIwMDBlMThiZTI1MzYxYTk0NGExZTIwNjQ5ZmY0ZGM2OGRiZTk0OGNkNTYwY2I5MTFhODU1OTE3ODdkNWQ5YWYwIiwic2hhM18zODQiOiIzNDljZmVhMGUxZGE0NWZlMmYzNjJhMWFjZjI1ZTczOWNiNGQ0NDdiM2NiODUzZDVkYWNjMzU5ZmRhMWE1M2FhYWU5OTM2ZmFhZWM1NmFhZDkwMThhYjgxMTI4ZjI3N2YiLCJzaGEzXzUxMiI6ImMxNDgwNGY1YTNjNWE4ZGEyMTAyODk1YTFjZGU4MmIwNGYwZmY4OTczMTc0MmY2NDQyY2NmNzQ1OTQzYWQ5NGViOWZmMTNhZDg3YjRmODkxN2M5NmY5ZjMwZjkwYTFhYTI4OTI3OTkwMjg0ZDJhMzcyMjA0NjE4MTNiNDI0MzEyIiwic2hha2VfMTI4IjoiN2IxY2RkMWUyMzUzMzk0OTg5M2UyMmZkMTAwZmU0YjJhMTU1MDJmMTNjMTI0YzhiZDgxY2QwZDdlOWEzMGNmOCIsInNoYWtlXzI1NiI6ImI0NjMzZThhMjNkZDM0ODk0ZTIyNzc0ODYyNTE1MzVjYWFlNjkyMTdmOTQ0NTc3MzE1NTljODBjNWQ3M2ZkOTMxZTFjMDJlZDI0Yjc3MzE3OTJjMjVlNTZhYjg3NjI4YmJiMDgxNTU0MjU2MWY5ZGI2NWE0NDk4NDFmNGQzYTU4In0K",
"eyJibGFrZTNfbXVsdGkiOiI2Y2M0MmU4NGRiOGQyZTliYjA4YjUxNWUwYzlmYzg2NTViNDUwNGRlZDM1MzBlZjFjNTFjZWEwOWUxYThiNGYxIiwiYmxha2UzX3NpbmdsZSI6IjZjYzQyZTg0ZGI4ZDJlOWJiMDhiNTE1ZTBjOWZjODY1NWI0NTA0ZGVkMzUzMGVmMWM1MWNlYTA5ZTFhOGI0ZjEiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZDQwNjk3NTJhYjQ0NzFhZDliMDY3YmUxMmRjNTM2ZjYiLCJzaGExIjoiOGRjZmVlMjZjZjUyOTllMDBjN2QwZjJiZTc0NmVmMTlkZjliZGExNCIsInNoYTIyNCI6IjhjMzAzOTU3ZjI3NDNiMjUwNmQyYzIzY2VmNmU4MTQ5MTllZmE2MWM0MTFiMDk5ZmMzODc2MmRjIiwic2hhMjU2IjoiZDk3ZjQ2OWJjMWZkMjhjMjZkMjJhN2Y3ODczNzlhZmM4NjY3ZmZmM2FhYTQ5NTE4NmQyZTM4OTU2MTBjZDJmMyIsInNoYTM4NCI6IjY0NmY0YWM0ZDA2YWJkZmE2MDAwN2VjZWNiOWNjOTk4ZmJkOTBiYzYwMmY3NTk2M2RhZDUzMGMzNGE5ZGE1YzY4NjhlMGIwMDJkZDNlMTM4ZjhmMjA2ODcyNzFkMDVjMSIsInNoYTUxMiI6ImYzZTU4NTA0YzYyOGUwYjViNzBhOTYxYThmODA1MDA1NjQ1M2E5NDlmNTgzNDhiYTNhZTVlMjdkNDRhNGJkMjc5ZjA3MmU1OGQ5YjEyOGE1NDc1MTU2ZmM3YzcxMGJkYjI3OWQ5OGFmN2EwYTI4Y2Y1ZDY2MmQxODY4Zjg3ZjI3IiwiYmxha2UyYiI6ImFhNjgyYmJjM2U1ZGRjNDZkNWUxN2VjMzRlNmEzZGY5ZjhiNWQyNzk0YTZkNmY0M2VjODMxZjhjOTU2OGYyY2RiOGE4YjAyNTE4MDA4YmY0Y2FhYTlhY2FhYjNkNzRmZmRiNGZlNDgwOTcwODU3OGJiZjNlNzJjYTc5ZDQwYzZmIiwiYmxha2UycyI6ImQ0ZGJlZTJkMmZlNDMwOGViYTkwMTY1MDdmMzI1ZmJiODZlMWQzNDQ0MjgzNzRlMjAwNjNiNWQ1MzkzZTExNjMiLCJzaGEzXzIyNCI6ImE1ZTM5NWZlNGRlYjIyY2JhNjgwMWFiZTliZjljMjM2YmMzYjkwZDdiN2ZjMTRhZDhjZjQ0NzBlIiwic2hhM18yNTYiOiIwOWYwZGVjODk0OWEzYmQzYzU3N2RjYzUyMTMwMGRiY2UwMjVjM2VjOTJkNzQ0MDJkNTE1ZDA4NTQwODg2NGY1Iiwic2hhM18zODQiOiJmMjEyNmM5NTcxODQ3NDZmNjYyMjE4MTRkMDZkZWQ3NDBhYWU3MDA4MTc0YjI0OTEzY2YwOTQzY2IwMTA5Y2QxNWI4YmMwOGY1YjUwMWYwYzhhOTY4MzUwYzgzY2I1ZWUiLCJzaGEzXzUxMiI6ImU1ZmEwMzIwMzk2YTJjMThjN2UxZjVlZmJiODYwYTU1M2NlMTlkMDQ0MWMxNWEwZTI1M2RiNjJkM2JmNjg0ZDI1OWIxYmQ4OTJkYTcyMDVjYTYyODQ2YzU0YWI1ODYxOTBmNDUxZDlmZmNkNDA5YmU5MzlhNWM1YWIyZDdkM2ZkIiwic2hha2VfMTI4IjoiNGI2MTllM2I4N2U1YTY4OTgxMjk0YzgzMmU0NzljZGI4MWFmODdlZTE4YzM1Zjc5ZjExODY5ZWEzNWUxN2I3MiIsInNoYWtlXzI1NiI6ImYzOWVkNmMxZmQ2NzVmMDg3ODAyYTc4ZTUwYWFkN2ZiYTZiM2QxNzhlZWYzMjRkMTI3ZTZjYmEwMGRjNzkwNTkxNjQ1Y2U1Y2NmMjhjYzVkNWRkODU1OWIzMDMxYTM3ZjE5NjhmYmFhNDQzMmI2ZWU0Yzg3ZWE2YTdkMmE2NWM2In0K",
"eyJibGFrZTNfbXVsdGkiOiJhNDRiZjJkMzVkZDI3OTZlZTI1NmY0MzVkODFhNTdhOGM0MjZhMzM5ZDc3NTVkMmNiMjdmMzU4ZjM0NTM4OWM2IiwiYmxha2UzX3NpbmdsZSI6ImE0NGJmMmQzNWRkMjc5NmVlMjU2ZjQzNWQ4MWE1N2E4YzQyNmEzMzlkNzc1NWQyY2IyN2YzNThmMzQ1Mzg5YzYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiOGU5OTMzMzEyZjg4NDY4MDg0ZmRiZWNjNDYyMTMxZTgiLCJzaGExIjoiNmI0MmZjZDFmMmQyNzUwYWNkY2JkMTUzMmQ4NjQ5YTM1YWI2NDYzNCIsInNoYTIyNCI6ImQ2Y2E2OTUxNzIzZjdjZjg0NzBjZWRjMmVhNjA2ODNmMWU4NDMzM2Q2NDM2MGIzOWIyMjZlZmQzIiwic2hhMjU2IjoiMDAxNGY5Yzg0YjcwMTFhMGJkNzliNzU0NGVjNzg4NDQzNWQ4ZGY0NmRjMDBiNDk0ZmFkYzA4NWQzNDM1NjI4MyIsInNoYTM4NCI6IjMxODg2OTYxODc4NWY3MWJlM2RlZjkyZDgyNzY2NjBhZGE0MGViYTdkMDk1M2Y0YTc5ODdlMThhNzFlNjBlY2EwY2YyM2YwMjVhMmQ4ZjUyMmNkZGY3MTcxODFhMTQxNSIsInNoYTUxMiI6IjdmZGQxN2NmOWU3ZTBhZDcwMzJjMDg1MTkyYWMxZmQ0ZmFhZjZkNWNlYzAzOTE5ZDk0MmZiZTIyNWNhNmIwZTg0NmQ4ZGI0ZjllYTQ5MjJlMTdhNTg4MTY4YzExMTM1NWZiZDQ1NTlmMmU5NDcwNjAwZWE1MzBhMDdiMzY0YWQwIiwiYmxha2UyYiI6IjI0ZjExZWI5M2VlN2YxOTI5NWZiZGU5MTczMmE0NGJkZGYxOWE1ZTQ4MWNmOWFhMjQ2M2UzNDllYjg0Mzc4ZDBkODFjNzY0YWQ1NTk1YjkxZjQzYzgxODcxNTRlYWU5NTZkY2ZjZTlkMWU2MTZjNTFkZThhZDZjZTBhODcyY2Q0IiwiYmxha2UycyI6IjVkZTUwZDUwMGYwYTBmOGRlMTEwOGE2ZmFkZGM4ODNlMTA3NmQ3MThiNmQxN2E4ZDVkMjgzZDdiNGYzZDU2OGEiLCJzaGEzXzIyNCI6IjFhNTA0OGNlYWZiYjg2ZDc4ZmNiNTI0ZTViYTc4NWQ2ZmY5NzY1ZTNlMzdhZWRjZmYxZGVjNGJhIiwic2hhM18yNTYiOiI0YjA0YjE1NTRmMzRkYTlmMjBmZDczM2IzNDg4NjE0ZWNhM2IwOWU1OTJjOGJlMmM0NjA1NjYyMWU0MjJmZDllIiwic2hhM18zODQiOiI1NjMwYjM2OGQ4MGM1YmM5MTgzM2VmNWM2YWUzOTJhNDE4NTNjYmM2MWJiNTI4ZDE4YWM1OWFjZGZiZWU1YThkMWMyZDE4MTM1ZGI2ZWQ2OTJlODFkZThmYTM3MzkxN2MiLCJzaGEzXzUxMiI6IjA2ODg4MGE1MmNiNDkzODYwZDhjOTVhOTFhZGFmZTYwZGYxODc2ZDhjYjFhNmI3NTU2ZjJjM2Y1NjFmMGYwZjMyZjZhYTA1YmVmN2FhYjQ5OWEwNTM0Zjk0Njc4MDEzODlmNDc0ODFiNzcxMjdjMDFiOGFhOTY4NGJhZGUzYmY2Iiwic2hha2VfMTI4IjoiODlmYTdjNDcwNGI4NGZkMWQ1M2E0MTBlN2ZjMzU3NWRhNmUxMGU1YzkzMjM1NWYyZWEyMWM4NDVhZDBlM2UxOCIsInNoYWtlXzI1NiI6IjE4NGNlMWY2NjdmYmIyODA5NWJhZmVkZTQzNTUzZjhkYzBhNGY1MDQwYWJlMjcxMzkzMzcwNDEyZWFiZTg0ZGJhNjI0Y2ZiZWE4YzUxZDU2YzkwMTM2Mjg2ODgyZmQ0Y2E3MzA3NzZjNWUzODFlYzI5MWYxYTczOTE1MDkyMTFmIn0K",
"eyJibGFrZTNfbXVsdGkiOiJhYjA2YjNmMDliNTExOTAzMTMzMzY5NDE2MTc4ZDk2ZjlkYTc3ZGEwOTgyNDJmN2VlMTVjNTNhNTRkMDZhNWVmIiwiYmxha2UzX3NpbmdsZSI6ImFiMDZiM2YwOWI1MTE5MDMxMzMzNjk0MTYxNzhkOTZmOWRhNzdkYTA5ODI0MmY3ZWUxNWM1M2E1NGQwNmE1ZWYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZWY0MjcxYjU3NTQwMjU4NGQ2OTI5ZWJkMGI3Nzk5NzYiLCJzaGExIjoiMzgzNzliYWQzZjZiZjc4MmM4OTgzOGY3YWVkMzRkNDNkMzNlYWM2MSIsInNoYTIyNCI6ImQ5ZDNiMjJkYmZlY2M1NTdlODAzNjg5M2M3ZWE0N2I0NTQzYzM2NzZhMDk4NzMxMzRhNjQ0OWEwIiwic2hhMjU2IjoiMjYxZGI3NmJlMGYxMzdlZWJkYmI5OGRlYWM0ZjcyMDdiOGUxMjdiY2MyZmMwODI5OGVjZDczYjQ3MjYxNjQ1NiIsInNoYTM4NCI6IjMzMjkwYWQxYjlhMmRkYmU0ODY3MWZiMTIxNDdiZWJhNjI4MjA1MDcwY2VkNjNiZTFmNGU5YWRhMjgwYWU2ZjZjNDkzYTY2MDllMGQ2YTIzMWU2ODU5ZmIyNGZhM2FjMCIsInNoYTUxMiI6IjAzMDZhMWI1NmNiYTdjNjJiNTNmNTk4MTAwMTQ3MDQ5ODBhNGRmZTdjZjQ5NTU4ZmMyMmQxZDczZDc5NzJmZTllODk2ZWRjMmEyYTQxYWVjNjRjZjkwZGUwYjI1NGM0MDBlZTU1YzcwZjk3OGVlMzk5NmM2YzhkNTBjYTI4YTdiIiwiYmxha2UyYiI6IjY1MDZhMDg1YWQ5MGZkZjk2NGJmMGE5NTFkZmVkMTllZTc0NGVjY2EyODQzZjQzYTI5NmFjZDM0M2RiODhhMDNlNTlkNmFmMGM1YWJkNTEzMzc4MTQ5Yjg3OTExMTVmODRmMDIyZWM1M2JmNGFjNDZhZDczNWIwMmJlYTM0MDk5IiwiYmxha2UycyI6IjdlZDQ3ZWQxOTg3MTk0YWFmNGIwMjQ3MWFkNTMyMmY3NTE3ZjI0OTcwMDc2Y2NmNDkzMWI0MzYxMDU1NzBlNDAiLCJzaGEzXzIyNCI6Ijk2MGM4MDExOTlhMGUzYWExNjdiNmU2MWVkMzE2ZDUzMDM2Yjk4M2UyOThkNWI5MjZmMDc3NDlhIiwic2hhM18yNTYiOiIzYzdmYWE1ZDE3Zjk2MGYxOTI2ZjNlNGIyZjc1ZjdiOWIyZDQ4NGFhNmEwM2ViOWNlMTI4NmM2OTE2YWEyM2RlIiwic2hhM18zODQiOiI5Y2Y0NDA1NWFjYzFlYjZmMDY1YjRjODcxYTYzNTM1MGE1ZjY0ODQwM2YwYTU0MWEzYzZhNjI3N2ViZjZmYTNjYmM1YmJiNjQwMDE4OGFlMWIxMTI2OGZmMDJiMzYzZDUiLCJzaGEzXzUxMiI6ImEyZDk3ZDRlYjYxM2UwZDViYTc2OTk2MzE2MzcxOGEwNDIxZDkxNTNiNjllYjM5MDRmZjI4ODRhZDdjNGJiYmIwNGY2Nzc1OTA1YmQxNGI2NTJmZTQ1Njg0YmI5MTQ3ZjBkYWViZjAxZjIzY2MzZDhkMjIzMTE0MGUzNjI4NTE5Iiwic2hha2VfMTI4IjoiNjkwMWMwYjg1MTg5ZTkyNTJiODI3MTc5NjE2MjRlMTM0MDQ1ZjlkMmI5MzM0MzVkM2Y0OThiZWIyN2Q3N2JiNSIsInNoYWtlXzI1NiI6ImIwMjA4ZTFkNDVjZWI0ODdiZDUwNzk3MWJiNWI3MjdjN2UyYmE3ZDliNWM2ZTEyYWE5YTNhOTY5YzcyNDRjODIwZDcyNDY1ODhlZWU3Yjk4ZWM1NzhjZWIxNjc3OTkxODljMWRkMmZkMmZmYWM4MWExZDAzZDFiNjMxOGRkMjBiIn0K",
]
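
A brief usage sketch of the validator, based only on the code above: hashes in `<algorithm>:<digest>` form are checked against the embedded list, and anything else is ignored. The digests below are made-up values, so both calls return without raising.

```python
from invokeai.backend.model_hash.hash_validator import validate_hash

validate_hash("blake3:0000000000000000000000000000000000000000000000000000000000000000")
validate_hash("not-a-prefixed-hash")  # no ":" -> returns immediately
```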
@@ -31,12 +31,13 @@ from typing_extensions import Annotated, Any, Dict

 from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
 from invokeai.app.util.misc import uuid_string
+from invokeai.backend.model_hash.hash_validator import validate_hash

 from ..raw_model import RawModel

 # ModelMixin is the base class for all diffusers and transformers models
 # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
-AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
+AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]


 class InvalidModelConfigException(Exception):
@@ -115,7 +116,7 @@ class SchedulerPredictionType(str, Enum):
 class ModelRepoVariant(str, Enum):
     """Various hugging face variants on the diffusers format."""

-    Default = ""  # model files without "fp16" or other qualifier - empty str
+    Default = ""  # model files without "fp16" or other qualifier
     FP16 = "fp16"
     FP32 = "fp32"
     ONNX = "onnx"
@@ -448,4 +449,6 @@ class ModelConfigFactory(object):
         model.key = key
         if isinstance(model, CheckpointConfigBase) and timestamp is not None:
             model.converted_at = timestamp
+        if model:
+            validate_hash(model.hash)
         return model  # type: ignore

@@ -30,12 +30,8 @@ def convert_ldm_vae_to_diffusers(
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

     vae = AutoencoderKL(**vae_config)
-    with torch.no_grad():
-        vae.load_state_dict(converted_vae_checkpoint)
-        del converted_vae_checkpoint  # Free memory
-        import gc
-        gc.collect()
-        vae.to(precision)
+    vae.load_state_dict(converted_vae_checkpoint)
+    vae.to(precision)

     if dump_path:
         vae.save_pretrained(dump_path, safe_serialization=True)
@@ -56,11 +52,7 @@ def convert_ckpt_to_diffusers(
         model to be written.
     """
     pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
-    with torch.no_grad():
-        del kwargs  # Free memory
-        import gc
-        gc.collect()
-        pipe = pipe.to(precision)
+    pipe = pipe.to(precision)

     # TO DO: save correct repo variant
     if dump_path:
@@ -83,11 +75,7 @@ def convert_controlnet_to_diffusers(
         model to be written.
     """
     pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
-    with torch.no_grad():
-        del kwargs  # Free memory
-        import gc
-        gc.collect()
-        pipe = pipe.to(precision)
+    pipe = pipe.to(precision)

     # TO DO: save correct repo variant
     if dump_path:

@@ -7,7 +7,7 @@ from importlib import import_module
 from pathlib import Path

 from .convert_cache.convert_cache_default import ModelConvertCache
-from .load_base import LoadedModel, ModelLoaderBase
+from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
 from .load_default import ModelLoader
 from .model_cache.model_cache_default import ModelCache
 from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
@@ -19,6 +19,7 @@ for module in loaders:

 __all__ = [
     "LoadedModel",
+    "LoadedModelWithoutConfig",
     "ModelCache",
     "ModelConvertCache",
     "ModelLoaderBase",
@@ -7,6 +7,7 @@ from pathlib import Path

 from invokeai.backend.util import GIG, directory_size
 from invokeai.backend.util.logging import InvokeAILogger
+from invokeai.backend.util.util import safe_filename

 from .convert_cache_base import ModelConvertCacheBase

@@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):

     def cache_path(self, key: str) -> Path:
         """Return the path for a model with the indicated key."""
+        key = safe_filename(self._cache_path, key)
         return self._cache_path / key

     def make_room(self, size: float) -> None:
@@ -4,10 +4,13 @@ Base class for model loading in InvokeAI.
 """

 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass
 from logging import Logger
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Dict, Generator, Optional, Tuple

+import torch
+
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_manager.config import (
@@ -20,10 +23,44 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod


 @dataclass
-class LoadedModel:
-    """Context manager object that mediates transfer from RAM<->VRAM."""
+class LoadedModelWithoutConfig:
+    """
+    Context manager object that mediates transfer from RAM<->VRAM.
+
+    This is a context manager object that has two distinct APIs:
+
+    1. Older API (deprecated):
+       Use the LoadedModel object directly as a context manager.
+       It will move the model into VRAM (on CUDA devices), and
+       return the model in a form suitable for passing to torch.
+       Example:
+       ```
+       loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae'))
+       with loaded_model as vae:
+           image = vae.decode(latents)[0]
+       ```
+
+    2. Newer API (recommended):
+       Call the LoadedModel's `model_on_device()` method in a
+       context. It returns a tuple consisting of a copy of
+       the model's state dict in CPU RAM followed by a copy
+       of the model in VRAM. The state dict is provided to allow
+       LoRAs and other model patchers to return the model to
+       its unpatched state without expensive copy and restore
+       operations.
+
+       Example:
+       ```
+       loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae'))
+       with loaded_model.model_on_device() as (state_dict, vae):
+           image = vae.decode(latents)[0]
+       ```
+
+    The state_dict should be treated as a read-only object and
+    never modified. Also be aware that some loadable models do
+    not have a state_dict, in which case this value will be None.
+    """

-    config: AnyModelConfig
     _locker: ModelLockerBase

     def __enter__(self) -> AnyModel:
@@ -35,12 +72,29 @@ class LoadedModel:
         """Context exit."""
         self._locker.unlock()

+    @contextmanager
+    def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
+        """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
+        locked_model = self._locker.lock()
+        try:
+            state_dict = self._locker.get_state_dict()
+            yield (state_dict, locked_model)
+        finally:
+            self._locker.unlock()
+
     @property
     def model(self) -> AnyModel:
         """Return the model without locking it."""
         return self._locker.model


+@dataclass
+class LoadedModel(LoadedModelWithoutConfig):
+    """Context manager object that mediates transfer from RAM<->VRAM."""
+
+    config: Optional[AnyModelConfig] = None
+
+
 # TODO(MM2):
 # Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
 # know about. I think the problem may be related to this class being an ABC.
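The intent of the new `model_on_device()` API is easiest to see in a short usage sketch. This is illustrative only: `loaded_vae` and `latents` are hypothetical placeholders supplied by the caller, and the decode call simply mirrors the example given in the docstring above.

```python
from invokeai.backend.model_manager.load import LoadedModel


def decode_latents(loaded_vae: LoadedModel, latents):
    # Newer API: yields (state_dict, model). state_dict is a read-only CPU copy of the
    # weights (or None for models without one); the model is on the execution device.
    with loaded_vae.model_on_device() as (state_dict, vae):
        return vae.decode(latents)[0]
```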
@@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
-from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import TorchDevice

@@ -84,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
         except IndexError:
             pass

-        cache_path: Path = self._convert_cache.cache_path(config.key)
+        cache_path: Path = self._convert_cache.cache_path(str(model_path))
         if self._needs_conversion(config, model_path, cache_path):
             loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
         else:
@@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
             config.key,
             submodel_type=submodel_type,
             model=loaded_model,
-            size=calc_model_size_by_data(loaded_model),
         )

         return self._ram_cache.get(
@@ -126,9 +125,7 @@ class ModelLoader(ModelLoaderBase):
                 if subtype == submodel_type:
                     continue
                 if submodel := getattr(pipeline, subtype.value, None):
-                    self._ram_cache.put(
-                        config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
-                    )
+                    self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
         return getattr(pipeline, submodel_type.value) if submodel_type else pipeline

     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
@@ -30,6 +30,11 @@ class ModelLockerBase(ABC):
         """Unlock the contained model, and remove it from VRAM."""
         pass

+    @abstractmethod
+    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
+        """Return the state dict (if any) for the cached model."""
+        pass
+
     @property
     @abstractmethod
     def model(self) -> AnyModel:
@@ -56,6 +61,11 @@ class CacheRecord(Generic[T]):
     and then injected into the model. When the model is finished, the VRAM
     copy of the state dict is deleted, and the RAM version is reinjected
     into the model.
+
+    The state_dict should be treated as a read-only attribute. Do not attempt
+    to patch or otherwise modify it. Instead, patch the copy of the state_dict
+    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
+    context manager call `model_on_device()`.
     """

     key: str
@@ -159,7 +169,6 @@ class ModelCacheBase(ABC, Generic[T]):
         self,
         key: str,
         model: T,
-        size: int,
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
@@ -29,6 +29,7 @@ import torch

 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

@@ -153,13 +154,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         key: str,
         model: AnyModel,
-        size: int,
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
         key = self._make_cache_key(key, submodel_type)
         if key in self._cached_models:
             return
+        size = calc_model_size_by_data(model)
         self.make_room(size)

         state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
@@ -252,12 +253,7 @@ class ModelCache(ModelCacheBase[AnyModel]):

         May raise a torch.cuda.OutOfMemoryError
         """
-        # These attributes are not in the base ModelMixin class but in various derived classes.
-        # Some models don't have these attributes, in which case they run in RAM/CPU.
         self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
-        if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
-            return
-
         source_device = cache_entry.device

         # Note: We compare device types only so that 'cuda' == 'cuda:0'.
@@ -265,6 +261,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if torch.device(source_device).type == torch.device(target_device).type:
             return

+        # Some models don't have a `to` method, in which case they run in RAM/CPU.
+        if not hasattr(cache_entry.model, "to"):
+            return
+
         # This roundabout method for moving the model around is done to avoid
         # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
         # When moving to VRAM, we copy (not move) each element of the state dict from
@@ -285,9 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
             else:
                 new_dict: Dict[str, torch.Tensor] = {}
                 for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(torch.device(target_device), copy=True)
+                    new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
                 cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device)
+            cache_entry.model.to(target_device, non_blocking=True)
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
@@ -2,6 +2,8 @@
 Base class and implementation of a class that moves models in and out of VRAM.
 """

+from typing import Dict, Optional
+
 import torch

 from invokeai.backend.model_manager import AnyModel
@@ -27,20 +29,18 @@ class ModelLocker(ModelLockerBase):
         """Return the model without moving it around."""
         return self._cache_entry.model

+    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
+        """Return the state dict (if any) for the cached model."""
+        return self._cache_entry.state_dict
+
     def lock(self) -> AnyModel:
         """Move the model into the execution device (GPU) and lock it."""
-        if not hasattr(self.model, "to"):
-            return self.model
-
-        # NOTE that the model has to have the to() method in order for this code to move it into GPU!
         self._cache_entry.lock()
         try:
             if self._cache.lazy_offloading:
                 self._cache.offload_unlocked_models(self._cache_entry.size)

             self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
             self._cache_entry.loaded = True

             self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
@@ -55,10 +55,7 @@ class ModelLocker(ModelLockerBase):

     def unlock(self) -> None:
         """Call upon exit from context."""
-        if not hasattr(self.model, "to"):
-            return
-
         self._cache_entry.unlock()
         if not self._cache.lazy_offloading:
-            self._cache.offload_unlocked_models(self._cache_entry.size)
+            self._cache.offload_unlocked_models(0)
         self._cache.print_cuda_stats()
@@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
         else:
             try:
                 config = self._load_diffusers_config(model_path, config_name="config.json")
-                class_name = config.get("_class_name", None)
-                if class_name:
+                if class_name := config.get("_class_name"):
                     result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
-                if config.get("model_type", None) == "clip_vision_model":
-                    class_name = config.get("architectures")
-                    assert class_name is not None
+                elif class_name := config.get("architectures"):
                     result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
-                if not class_name:
+                else:
                     raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
             except KeyError as e:
                 raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
@@ -22,8 +22,7 @@ from .generic_diffusers import GenericDiffusersLoader


 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
 class VAELoader(GenericDiffusersLoader):
     """Class to load VAE models."""

@@ -40,12 +39,8 @@ class VAELoader(GenericDiffusersLoader):
             return True

     def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
-        # TODO(MM2): check whether sdxl VAE models convert.
-        if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
-            raise Exception(f"VAE conversion not supported for model type: {config.base}")
-        else:
-            assert isinstance(config, CheckpointConfigBase)
-            config_file = self._app_config.legacy_conf_path / config.config_path
+        assert isinstance(config, CheckpointConfigBase)
+        config_file = self._app_config.legacy_conf_path / config.config_path

         if model_path.suffix == ".safetensors":
             checkpoint = safetensors_load_file(model_path, device="cpu")
@@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
             assert s.size is not None
             files.append(
                 RemoteModelFile(
-                    url=hf_hub_url(id, s.rfilename, revision=variant),
+                    url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
                     path=Path(name, s.rfilename),
                     size=s.size,
                     sha256=s.lfs.get("sha256") if s.lfs else None,
@@ -37,9 +37,12 @@ class RemoteModelFile(BaseModel):

     url: AnyHttpUrl = Field(description="The url to download this model file")
     path: Path = Field(description="The path to the file, relative to the model root")
-    size: int = Field(description="The size of this file, in bytes")
+    size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
     sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)

+    def __hash__(self) -> int:
+        return hash(str(self))
+

 class ModelMetadataBase(BaseModel):
     """Base class for model metadata information."""
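The added `__hash__` hashes the model's string representation, which together with pydantic's field-based equality allows identical file records to collapse when placed in a set. A small illustrative check, using made-up values and an assumed module path:

```python
from pathlib import Path

# Assumed import path for RemoteModelFile; adjust to wherever the model is exported from.
from invokeai.backend.model_manager.metadata.metadata_base import RemoteModelFile

f1 = RemoteModelFile(url="https://example.com/model.safetensors", path=Path("model.safetensors"), size=10)
f2 = RemoteModelFile(url="https://example.com/model.safetensors", path=Path("model.safetensors"), size=10)
# With __hash__ defined, duplicate records deduplicate in a set.
assert len({f1, f2}) == 1
```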
@@ -10,7 +10,7 @@ from picklescan.scanner import scan_file_path
 import invokeai.backend.util.logging as logger
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
-from invokeai.backend.util.util import SilenceWarnings
+from invokeai.backend.util.silence_warnings import SilenceWarnings

 from .config import (
     AnyModelConfig,
@@ -451,8 +451,16 @@ class PipelineCheckpointProbe(CheckpointProbeBase):

 class VaeCheckpointProbe(CheckpointProbeBase):
     def get_base_type(self) -> BaseModelType:
-        # I can't find any standalone 2.X VAEs to test with!
-        return BaseModelType.StableDiffusion1
+        # VAEs of all base types have the same structure, so we wimp out and
+        # guess using the name.
+        for regexp, basetype in [
+            (r"xl", BaseModelType.StableDiffusionXL),
+            (r"sd2", BaseModelType.StableDiffusion2),
+            (r"vae", BaseModelType.StableDiffusion1),
+        ]:
+            if re.search(regexp, self.model_path.name, re.IGNORECASE):
+                return basetype
+        raise InvalidModelConfigException("Cannot determine base type")


 class LoRACheckpointProbe(CheckpointProbeBase):
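The ordering of the patterns above matters: a filename like `sdxl_vae.safetensors` must hit the `xl` pattern before falling through to the catch-all `vae` pattern. Here is a standalone sketch of the same lookup; the base-type strings and filenames are stand-ins for the `BaseModelType` enum and real checkpoint names:

```python
import re

# Patterns are checked in order, most specific first.
PATTERNS = [(r"xl", "sdxl"), (r"sd2", "sd-2"), (r"vae", "sd-1")]


def guess_base(name: str) -> str:
    for regexp, base in PATTERNS:
        if re.search(regexp, name, re.IGNORECASE):
            return base
    raise ValueError("Cannot determine base type")


assert guess_base("sdxl_vae.safetensors") == "sdxl"   # matches "xl" before the generic "vae"
assert guess_base("vae-ft-mse-840000.ckpt") == "sd-1"  # falls through to the catch-all
```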
@@ -5,7 +5,7 @@ from __future__ import annotations

 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -66,8 +66,14 @@ class ModelPatcher:
         cls,
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(unet, loras, "lora_unet_"):
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Generator[None, None, None]:
+        with cls.apply_lora(
+            unet,
+            loras=loras,
+            prefix="lora_unet_",
+            model_state_dict=model_state_dict,
+        ):
             yield

     @classmethod
@@ -76,28 +82,9 @@ class ModelPatcher:
         cls,
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te_"):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_sdxl_lora_text_encoder(
-        cls,
-        text_encoder: CLIPTextModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te1_"):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_sdxl_lora_text_encoder2(
-        cls,
-        text_encoder: CLIPTextModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te2_"):
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Generator[None, None, None]:
+        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
             yield

     @classmethod
@@ -107,7 +94,16 @@ class ModelPatcher:
         model: AnyModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
-    ) -> None:
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Generator[None, None, None]:
+        """
+        Apply one or more LoRAs to a model.
+
+        :param model: The model to patch.
+        :param loras: An iterator that returns the LoRA to patch in and its patch weight.
+        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
+        :model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        """
         original_weights = {}
         try:
             with torch.no_grad():
@@ -133,19 +129,22 @@ class ModelPatcher:
                     dtype = module.weight.dtype

                     if module_key not in original_weights:
-                        original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
+                        if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
+                            original_weights[module_key] = model_state_dict[module_key + ".weight"]
+                        else:
+                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)

                     layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0

                     # We intentionally move to the target device first, then cast. Experimentally, this was found to
                     # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                     # same thing in a single call to '.to(...)'.
-                    layer.to(device=device)
-                    layer.to(dtype=torch.float32)
+                    layer.to(device=device, non_blocking=True)
+                    layer.to(dtype=torch.float32, non_blocking=True)
                     # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                     # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                     layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                    layer.to(device=torch.device("cpu"))
+                    layer.to(device=torch.device("cpu"), non_blocking=True)

                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                     if module.weight.shape != layer_weight.shape:
@@ -154,7 +153,7 @@ class ModelPatcher:
                         layer_weight = layer_weight.reshape(module.weight.shape)

                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                    module.weight += layer_weight.to(dtype=dtype)
+                    module.weight += layer_weight.to(dtype=dtype, non_blocking=True)

             yield  # wait for context manager exit

@@ -162,7 +161,7 @@ class ModelPatcher:
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
+                    model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)

     @classmethod
     @contextmanager
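The new `model_state_dict` parameter is designed to be fed from `LoadedModel.model_on_device()`: the patcher records original weights from the CPU state dict rather than copying them off the patched module. A minimal sketch of the intended pairing follows; the wrapper name `apply_lora_unet` and the `invokeai.backend.model_patcher` module path are assumptions based on the surrounding hunks, and `loras` is a placeholder iterator of `(LoRAModelRaw, weight)` pairs.

```python
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_patcher import ModelPatcher  # assumed module path


def denoise_with_loras(loaded_unet: LoadedModel, loras) -> None:
    # Hand the read-only CPU state dict to the patcher so that, on exit, the original
    # weights are restored from RAM instead of being copied back from the patched module.
    with loaded_unet.model_on_device() as (state_dict, unet):
        with ModelPatcher.apply_lora_unet(unet, loras=loras, model_state_dict=state_dict):
            pass  # run the denoising loop here with the patched UNet
```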
@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, Union

 import numpy as np
 import onnx
+import torch
 from onnx import numpy_helper
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers

@@ -188,6 +189,15 @@ class IAIOnnxRuntimeModel(RawModel):
         # return self.io_binding.copy_outputs_to_cpu()
         return self.session.run(None, inputs)

+    # compatability with RawModel ABC
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        pass
+
     # compatability with diffusers load code
     @classmethod
     def from_pretrained(
@@ -10,6 +10,20 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
 that adds additional methods and attributes.
 """

-class RawModel:
-    """Base class for 'Raw' model wrappers."""
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+
+class RawModel(ABC):
+    """Abstract base class for 'Raw' model wrappers."""
+
+    @abstractmethod
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        pass
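With `RawModel` now an ABC, every raw wrapper has to implement `to()` so the model cache can move (or deliberately ignore) device transfers. A minimal hypothetical subclass, assuming the module path implied by the `from ..raw_model import RawModel` import seen earlier:

```python
from typing import Dict, Optional

import torch

from invokeai.backend.raw_model import RawModel  # path inferred from the relative import above


class TensorDictRawModel(RawModel):
    """Illustrative concrete RawModel: wraps a dict of tensors and satisfies the abstract to()."""

    def __init__(self, tensors: Dict[str, torch.Tensor]) -> None:
        self._tensors = tensors

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        # Move every wrapped tensor; a wrapper with nothing to move could simply pass.
        for key, value in self._tensors.items():
            self._tensors[key] = value.to(device=device, dtype=dtype, non_blocking=non_blocking)
```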
@@ -10,12 +10,11 @@ import PIL.Image
 import psutil
 import torch
 import torchvision.transforms as T
-from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.models.controlnet import ControlNetModel
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 from diffusers.utils.import_utils import is_xformers_available
 from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -26,6 +25,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
 from invokeai.backend.util.attention import auto_detect_slice_size
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.hotfixes import ControlNetModel


 @dataclass
@@ -38,56 +38,18 @@ class PipelineIntermediateState:
     predicted_original: Optional[torch.Tensor] = None


-@dataclass
-class AddsMaskLatents:
-    """Add the channels required for inpainting model input.
-
-    The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
-    and the latent encoding of the base image.
-
-    This class assumes the same mask and base image should apply to all items in the batch.
-    """
-
-    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
-    mask: torch.Tensor
-    initial_image_latents: torch.Tensor
-
-    def __call__(
-        self,
-        latents: torch.Tensor,
-        t: torch.Tensor,
-        text_embeddings: torch.Tensor,
-        **kwargs,
-    ) -> torch.Tensor:
-        model_input = self.add_mask_channels(latents)
-        return self.forward(model_input, t, text_embeddings, **kwargs)
-
-    def add_mask_channels(self, latents):
-        batch_size = latents.size(0)
-        # duplicate mask and latents for each batch
-        mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
-        image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
-        # add mask and image as additional channels
-        model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
-        return model_input
-
-
-def are_like_tensors(a: torch.Tensor, b: object) -> bool:
-    return isinstance(b, torch.Tensor) and (a.size() == b.size())
-
-
 @dataclass
 class AddsMaskGuidance:
-    mask: torch.FloatTensor
-    mask_latents: torch.FloatTensor
+    mask: torch.Tensor
+    mask_latents: torch.Tensor
     scheduler: SchedulerMixin
     noise: torch.Tensor
-    gradient_mask: bool
+    is_gradient_mask: bool

     def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
         return self.apply_mask(latents, t)

-    def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
+    def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
         batch_size = latents.size(0)
         mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
         if t.dim() == 0:
@@ -100,7 +62,7 @@ class AddsMaskGuidance:
         # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
         # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
         mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
-        if self.gradient_mask:
+        if self.is_gradient_mask:
             threshhold = (t.item()) / self.scheduler.config.num_train_timesteps
             mask_bool = mask > threshhold  # I don't know when mask got inverted, but it did
             masked_input = torch.where(mask_bool, latents, mask_latents)
@@ -200,7 +162,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         safety_checker: Optional[StableDiffusionSafetyChecker],
         feature_extractor: Optional[CLIPFeatureExtractor],
         requires_safety_checker: bool = False,
-        control_model: ControlNetModel = None,
     ):
         super().__init__(
             vae=vae,
@@ -214,8 +175,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         )

         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
-        self.control_model = control_model
-        self.use_ip_adapter = False

     def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
         """
@@ -280,116 +239,128 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
         raise Exception("Should not be called")

+    def add_inpainting_channels_to_latents(
+        self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor
+    ):
+        """Given a `latents` tensor, adds the mask and image latents channels required for inpainting.
+
+        Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an
+        input of shape (N, 9, H, W). The 9 channels are defined as follows:
+        - Channel 0-3: The latents being denoised.
+        - Channel 4: The mask indicating which parts of the image are being inpainted.
+        - Channel 5-8: The latent representation of the masked reference image being inpainted.
+
+        This function assumes that the same mask and base image should apply to all items in the batch.
+        """
+        # Validate assumptions about input tensor shapes.
+        batch_size, latent_channels, latent_height, latent_width = latents.shape
+        assert latent_channels == 4
+        assert list(masked_ref_image_latents.shape) == [1, 4, latent_height, latent_width]
+        assert list(inpainting_mask.shape) == [1, 1, latent_height, latent_width]
+
+        # Repeat original_image_latents and inpainting_mask to match the latents batch size.
+        original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
+        inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1)
+
+        # Concatenate along the channel dimension.
+        return torch.cat([latents, inpainting_mask, original_image_latents], dim=1)
+
     def latents_from_embeddings(
         self,
         latents: torch.Tensor,
-        num_inference_steps: int,
         scheduler_step_kwargs: dict[str, Any],
         conditioning_data: TextConditioningData,
-        *,
         noise: Optional[torch.Tensor],
+        seed: int,
         timesteps: torch.Tensor,
         init_timestep: torch.Tensor,
-        additional_guidance: List[Callable] = None,
-        callback: Callable[[PipelineIntermediateState], None] = None,
-        control_data: List[ControlNetData] = None,
+        callback: Callable[[PipelineIntermediateState], None],
+        control_data: list[ControlNetData] | None = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
         mask: Optional[torch.Tensor] = None,
         masked_latents: Optional[torch.Tensor] = None,
-        gradient_mask: Optional[bool] = False,
-        seed: int,
+        is_gradient_mask: bool = False,
     ) -> torch.Tensor:
+        """Denoise the latents.
+
+        Args:
+            latents: The latent-space image to denoise.
+                - If we are inpainting, this is the initial latent image before noise has been added.
+                - If we are generating a new image, this should be initialized to zeros.
+                - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
+            scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
+            conditioning_data: Text conditionging data.
+            noise: Noise used for two purposes:
+                1. Used by the scheduler to noise the initial `latents` before denoising.
+                2. Used to noise the `masked_latents` when inpainting.
+                `noise` should be None if the `latents` tensor has already been noised.
+            seed: The seed used to generate the noise for the denoising process.
+                HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
+                same noise used earlier in the pipeline. This should really be handled in a clearer way.
+            timesteps: The timestep schedule for the denoising process.
+            init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so
+                should be populated if you want noise applied *even* if timesteps is empty.
+            callback: A callback function that is called to report progress during the denoising process.
+            control_data: ControlNet data.
+            ip_adapter_data: IP-Adapter data.
+            t2i_adapter_data: T2I-Adapter data.
+            mask: A mask indicating which parts of the image are being inpainted. The presence of mask is used to
+                determine whether we are inpainting or not. `mask` should have the same spatial dimensions as the
+                `latents` tensor.
+                TODO(ryand): Check and document the expected dtype, range, and values used to represent
+                foreground/background.
+            masked_latents: A latent-space representation of a masked inpainting reference image. This tensor is only
+                used if an *inpainting* model is being used i.e. this tensor is not used when inpainting with a standard
+                SD UNet model.
+            is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
+        """
         if init_timestep.shape[0] == 0:
             return latents

-        if additional_guidance is None:
-            additional_guidance = []
-
         orig_latents = latents.clone()

         batch_size = latents.shape[0]
-        batched_t = init_timestep.expand(batch_size)
+        batched_init_timestep = init_timestep.expand(batch_size)

+        # noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
         if noise is not None:
+            # TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
+            # full noise. Investigate the history of why this got commented out.
             # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
-            latents = self.scheduler.add_noise(latents, noise, batched_t)
+            latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)

-        if mask is not None:
-            if is_inpainting_model(self.unet):
-                if masked_latents is None:
-                    raise Exception("Source image required for inpaint mask when inpaint model used!")
-
-                self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
-                    self._unet_forward, mask, masked_latents
-                )
-            else:
-                # if no noise provided, noisify unmasked area based on seed
-                if noise is None:
-                    noise = torch.randn(
-                        orig_latents.shape,
-                        dtype=torch.float32,
-                        device="cpu",
-                        generator=torch.Generator(device="cpu").manual_seed(seed),
-                    ).to(device=orig_latents.device, dtype=orig_latents.dtype)
-
-                additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
-
-        try:
-            latents = self.generate_latents_from_embeddings(
-                latents,
-                timesteps,
-                conditioning_data,
-                scheduler_step_kwargs=scheduler_step_kwargs,
-                additional_guidance=additional_guidance,
-                control_data=control_data,
-                ip_adapter_data=ip_adapter_data,
-                t2i_adapter_data=t2i_adapter_data,
-                callback=callback,
-            )
-        finally:
-            self.invokeai_diffuser.model_forward_callback = self._unet_forward
-
-        # restore unmasked part after the last step is completed
-        # in-process masking happens before each step
-        if mask is not None:
-            if gradient_mask:
-                latents = torch.where(mask > 0, latents, orig_latents)
-            else:
-                latents = torch.lerp(
-                    orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
-                )
-
-        return latents
-
-    def generate_latents_from_embeddings(
-        self,
-        latents: torch.Tensor,
-        timesteps,
-        conditioning_data: TextConditioningData,
-        scheduler_step_kwargs: dict[str, Any],
-        *,
-        additional_guidance: List[Callable] = None,
-        control_data: List[ControlNetData] = None,
-        ip_adapter_data: Optional[list[IPAdapterData]] = None,
-        t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        callback: Callable[[PipelineIntermediateState], None] = None,
-    ) -> torch.Tensor:
         self._adjust_memory_efficient_attention(latents)
-        if additional_guidance is None:
-            additional_guidance = []

-        batch_size = latents.shape[0]
-
-        if timesteps.shape[0] == 0:
-            return latents
+        # Handle mask guidance (a.k.a. inpainting).
+        mask_guidance: AddsMaskGuidance | None = None
+        if mask is not None and not is_inpainting_model(self.unet):
+            # We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
+            # apply mask guidance to the latents.
+
+            # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
+            # We still need noise for inpainting, so we generate it from the seed here.
+            if noise is None:
+                noise = torch.randn(
+                    orig_latents.shape,
+                    dtype=torch.float32,
+                    device="cpu",
+                    generator=torch.Generator(device="cpu").manual_seed(seed),
+                ).to(device=orig_latents.device, dtype=orig_latents.dtype)
+
+            mask_guidance = AddsMaskGuidance(
+                mask=mask,
+                mask_latents=orig_latents,
+                scheduler=self.scheduler,
+                noise=noise,
+                is_gradient_mask=is_gradient_mask,
+            )

         use_ip_adapter = ip_adapter_data is not None
         use_regional_prompting = (
             conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
         )
         unet_attention_patcher = None
-        self.use_ip_adapter = use_ip_adapter
         attn_ctx = nullcontext()

         if use_ip_adapter or use_regional_prompting:
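The 9-channel layout documented in `add_inpainting_channels_to_latents()` above (4 latent channels, 1 mask channel, 4 masked-reference channels) can be sanity-checked with a tiny shape exercise. The sizes below are arbitrary examples, not values used by the pipeline:

```python
import torch

# Batch of 2 images over an 8x8 latent grid.
latents = torch.zeros(2, 4, 8, 8)
mask = torch.zeros(1, 1, 8, 8).expand(2, -1, -1, -1)          # channel 4
ref_latents = torch.zeros(1, 4, 8, 8).expand(2, -1, -1, -1)    # channels 5-8

# Concatenating along the channel dimension yields the 9-channel input an inpainting UNet expects.
model_input = torch.cat([latents, mask, ref_latents], dim=1)
assert model_input.shape == (2, 9, 8, 8)
```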
@@ -402,28 +373,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

         with attn_ctx:
-            if callback is not None:
-                callback(
-                    PipelineIntermediateState(
-                        step=-1,
-                        order=self.scheduler.order,
-                        total_steps=len(timesteps),
-                        timestep=self.scheduler.config.num_train_timesteps,
-                        latents=latents,
-                    )
-                )
+            callback(
+                PipelineIntermediateState(
+                    step=-1,
+                    order=self.scheduler.order,
+                    total_steps=len(timesteps),
+                    timestep=self.scheduler.config.num_train_timesteps,
+                    latents=latents,
+                )
+            )

-            # print("timesteps:", timesteps)
             for i, t in enumerate(self.progress_bar(timesteps)):
                 batched_t = t.expand(batch_size)
                 step_output = self.step(
-                    batched_t,
-                    latents,
-                    conditioning_data,
+                    t=batched_t,
+                    latents=latents,
+                    conditioning_data=conditioning_data,
                     step_index=i,
                     total_step_count=len(timesteps),
                     scheduler_step_kwargs=scheduler_step_kwargs,
-                    additional_guidance=additional_guidance,
+                    mask_guidance=mask_guidance,
+                    mask=mask,
+                    masked_latents=masked_latents,
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
@@ -431,19 +402,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                latents = step_output.prev_sample
                predicted_original = getattr(step_output, "pred_original_sample", None)

-                if callback is not None:
-                    callback(
-                        PipelineIntermediateState(
-                            step=i,
-                            order=self.scheduler.order,
-                            total_steps=len(timesteps),
-                            timestep=int(t),
-                            latents=latents,
-                            predicted_original=predicted_original,
-                        )
-                    )
+                callback(
+                    PipelineIntermediateState(
+                        step=i,
+                        order=self.scheduler.order,
+                        total_steps=len(timesteps),
+                        timestep=int(t),
+                        latents=latents,
+                        predicted_original=predicted_original,
+                    )
+                )

-        return latents
+        # restore unmasked part after the last step is completed
+        # in-process masking happens before each step
+        if mask is not None:
+            if is_gradient_mask:
+                latents = torch.where(mask > 0, latents, orig_latents)
+            else:
+                latents = torch.lerp(
+                    orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
+                )
+
+        return latents

    @torch.inference_mode()
    def step(
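The mask-restore branch added above blends the denoised latents back into the unmasked original. A minimal toy sketch (not part of the diff; values are made up) of the torch.lerp blending it relies on, where weight 0 keeps the original latent and weight 1 keeps the denoised one:

```
import torch

# Soft-mask blending: result = orig + mask * (denoised - orig)
orig = torch.zeros(4)
denoised = torch.ones(4)
mask = torch.tensor([0.0, 0.25, 0.5, 1.0])
print(torch.lerp(orig, denoised, mask))  # tensor([0.0000, 0.2500, 0.5000, 1.0000])
```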
@@ -454,19 +434,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        step_index: int,
        total_step_count: int,
        scheduler_step_kwargs: dict[str, Any],
-        additional_guidance: List[Callable] = None,
-        control_data: List[ControlNetData] = None,
+        mask_guidance: AddsMaskGuidance | None,
+        mask: torch.Tensor | None,
+        masked_latents: torch.Tensor | None,
+        control_data: list[ControlNetData] | None = None,
        ip_adapter_data: Optional[list[IPAdapterData]] = None,
        t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
    ):
        # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
        timestep = t[0]
-        if additional_guidance is None:
-            additional_guidance = []

-        # one day we will expand this extension point, but for now it just does denoise masking
-        for guidance in additional_guidance:
-            latents = guidance(latents, timestep)
+        # Handle masked image-to-image (a.k.a inpainting).
+        if mask_guidance is not None:
+            # NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
+            latents = mask_guidance(latents, timestep)

        # TODO: should this scaling happen here or inside self._unet_forward?
        # i.e. before or after passing it to InvokeAIDiffuserComponent
@@ -514,6 +495,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

                down_intrablock_additional_residuals = accum_adapter_state

+        # Handle inpainting models.
+        if is_inpainting_model(self.unet):
+            # NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
+            # self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
+            # latents.
+            if mask is not None:
+                if masked_latents is None:
+                    raise ValueError("Source image required for inpaint mask when inpaint model used!")
+                latent_model_input = self.add_inpainting_channels_to_latents(
+                    latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
+                )
+            else:
+                # We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
+                # We generate a global mask and empty original image so that we can still generate in this
+                # configuration.
+                # TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
+                # to do this.
+                # TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
+                # mask and original image once rather than on every denoising step.
+                latent_model_input = self.add_inpainting_channels_to_latents(
+                    latents=latent_model_input,
+                    masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
+                    inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
+                )

        uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
            sample=latent_model_input,
            timestep=t,  # TODO: debug how handled batched and non batched timesteps
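The add_inpainting_channels_to_latents(...) method itself is not shown in this diff; for orientation, here is a hedged sketch of what such a helper is generally expected to do for inpainting UNets (4 + 1 + 4 = 9 input channels). The function name and exact channel layout below are assumptions for illustration, not the pipeline's actual implementation:

```
import torch

def add_inpainting_channels_sketch(
    latents: torch.Tensor,  # (B, 4, H, W) scaled noisy latents
    masked_ref_image_latents: torch.Tensor,  # (1, 4, H, W) VAE latents of the masked source image
    inpainting_mask: torch.Tensor,  # (1, 1, H, W), 1.0 where the model should repaint
) -> torch.Tensor:
    # Inpainting checkpoints expect the mask and masked-image latents concatenated
    # onto the noisy latents along the channel dimension.
    b = latents.shape[0]
    mask = inpainting_mask.expand(b, -1, -1, -1)
    ref = masked_ref_image_latents.expand(b, -1, -1, -1)
    return torch.cat([latents, mask, ref], dim=1)
```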
@@ -542,17 +548,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        # compute the previous noisy sample x_t -> x_t-1
        step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)

-        # TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
-        for guidance in additional_guidance:
-            # apply the mask to any "denoised" or "pred_original_sample" fields
+        # TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
+        # again.
+        if mask_guidance is not None:
+            # Apply the mask to any "denoised" or "pred_original_sample" fields.
            if hasattr(step_output, "denoised"):
-                step_output.pred_original_sample = guidance(step_output.denoised, self.scheduler.timesteps[-1])
+                step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
            elif hasattr(step_output, "pred_original_sample"):
-                step_output.pred_original_sample = guidance(
+                step_output.pred_original_sample = mask_guidance(
                    step_output.pred_original_sample, self.scheduler.timesteps[-1]
                )
            else:
-                step_output.pred_original_sample = guidance(latents, self.scheduler.timesteps[-1])
+                step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])

        return step_output

@@ -575,17 +582,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        **kwargs,
    ):
        """predict the noise residual"""
-        if is_inpainting_model(self.unet) and latents.size(1) == 4:
-            # Pad out normal non-inpainting inputs for an inpainting model.
-            # FIXME: There are too many layers of functions and we have too many different ways of
-            # overriding things! This should get handled in a way more consistent with the other
-            # use of AddsMaskLatents.
-            latents = AddsMaskLatents(
-                self._unet_forward,
-                mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
-                initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
-            ).add_mask_channels(latents)

        # First three args should be positional, not keywords, so torch hooks can see them.
        return self.unet(
            latents,
invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py (new file, 170 lines)
@@ -0,0 +1,170 @@
from __future__ import annotations

import copy
from dataclasses import dataclass
from typing import Any, Callable, Optional

import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin

from invokeai.backend.stable_diffusion.diffusers_pipeline import (
    ControlNetData,
    PipelineIntermediateState,
    StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from invokeai.backend.tiles.utils import Tile


@dataclass
class MultiDiffusionRegionConditioning:
    # Region coords in latent space.
    region: Tile
    text_conditioning_data: TextConditioningData
    control_data: list[ControlNetData]


class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
    """A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""

    def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
        """Validate that regional conditioning is not used."""
        for region_conditioning in multi_diffusion_conditioning:
            if (
                region_conditioning.text_conditioning_data.cond_regions is not None
                or region_conditioning.text_conditioning_data.uncond_regions is not None
            ):
                raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")

    def multi_diffusion_denoise(
        self,
        multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
        target_overlap: int,
        latents: torch.Tensor,
        scheduler_step_kwargs: dict[str, Any],
        noise: Optional[torch.Tensor],
        timesteps: torch.Tensor,
        init_timestep: torch.Tensor,
        callback: Callable[[PipelineIntermediateState], None],
    ) -> torch.Tensor:
        self._check_regional_prompting(multi_diffusion_conditioning)

        if init_timestep.shape[0] == 0:
            return latents

        batch_size, _, latent_height, latent_width = latents.shape
        batched_init_timestep = init_timestep.expand(batch_size)

        # noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
        if noise is not None:
            # TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
            # full noise. Investigate the history of why this got commented out.
            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
            latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)

        # TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
        # cropping into regions.
        self._adjust_memory_efficient_attention(latents)

        # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
        # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
        # separate scheduler state for each region batch.
        # TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler
        # statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect
        # as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when
        # multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each
        # scheduler to determine how its state needs to be updated for compatibility with Multi-Diffusion.
        region_batch_schedulers: list[SchedulerMixin] = [
            copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
        ]

        callback(
            PipelineIntermediateState(
                step=-1,
                order=self.scheduler.order,
                total_steps=len(timesteps),
                timestep=self.scheduler.config.num_train_timesteps,
                latents=latents,
            )
        )

        for i, t in enumerate(self.progress_bar(timesteps)):
            batched_t = t.expand(batch_size)

            merged_latents = torch.zeros_like(latents)
            merged_latents_weights = torch.zeros(
                (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
            )
            merged_pred_original: torch.Tensor | None = None
            for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
                # Switch to the scheduler for the region batch.
                self.scheduler = region_batch_schedulers[region_idx]

                # Crop the inputs to the region.
                region_latents = latents[
                    :,
                    :,
                    region_conditioning.region.coords.top : region_conditioning.region.coords.bottom,
                    region_conditioning.region.coords.left : region_conditioning.region.coords.right,
                ]

                # Run the denoising step on the region.
                step_output = self.step(
                    t=batched_t,
                    latents=region_latents,
                    conditioning_data=region_conditioning.text_conditioning_data,
                    step_index=i,
                    total_step_count=len(timesteps),
                    scheduler_step_kwargs=scheduler_step_kwargs,
                    mask_guidance=None,
                    mask=None,
                    masked_latents=None,
                    control_data=region_conditioning.control_data,
                )

                # Store the results from the region.
                # If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
                # affected tiles to achieve the target overlap.
                region = region_conditioning.region
                top_adjustment = max(0, region.overlap.top - target_overlap)
                left_adjustment = max(0, region.overlap.left - target_overlap)
                region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
                region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
                merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
                    :, :, top_adjustment:, left_adjustment:
                ]
                # For now, we treat every region as having the same weight.
                merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0

                pred_orig_sample = getattr(step_output, "pred_original_sample", None)
                if pred_orig_sample is not None:
                    # If one region has pred_original_sample, then we can assume that all regions will have it, because
                    # they all use the same scheduler.
                    if merged_pred_original is None:
                        merged_pred_original = torch.zeros_like(latents)
                    merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
                        :, :, top_adjustment:, left_adjustment:
                    ]

            # Normalize the merged results.
            latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
            # For debugging, uncomment this line to visualize the region seams:
            # latents = torch.where(merged_latents_weights > 1, 0.0, latents)
            predicted_original = None
            if merged_pred_original is not None:
                predicted_original = torch.where(
                    merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
                )

            callback(
                PipelineIntermediateState(
                    step=i,
                    order=self.scheduler.order,
                    total_steps=len(timesteps),
                    timestep=int(t),
                    latents=latents,
                    predicted_original=predicted_original,
                )
            )

        return latents
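The merge step in the new pipeline accumulates each region's output into merged_latents and a per-pixel weight map, then divides so that overlapping pixels end up as the average of the regions that cover them. A standalone toy illustration of that blending (not part of the new file; the shapes and values are made up):

```
import torch

canvas = torch.zeros(1, 1, 4, 6)
weights = torch.zeros(1, 1, 4, 6)

canvas[..., 0:4] += 1.0  # "region A" covers columns 0-3
weights[..., 0:4] += 1.0
canvas[..., 2:6] += 3.0  # "region B" covers columns 2-5
weights[..., 2:6] += 1.0

# Overlapping columns 2-3 average to (1.0 + 3.0) / 2 = 2.0, mirroring the
# merged_latents / merged_latents_weights normalization above.
blended = torch.where(weights > 0, canvas / weights, canvas)
print(blended[0, 0, 0])  # tensor([1., 1., 2., 2., 3., 3.])
```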
@@ -65,6 +65,18 @@ class TextualInversionModelRaw(RawModel):

        return result

+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        if not torch.cuda.is_available():
+            return
+        for emb in [self.embedding, self.embedding_2]:
+            if emb is not None:
+                emb.to(device=device, dtype=dtype, non_blocking=non_blocking)


class TextualInversionManager(BaseTextualInversionManager):
    """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
@@ -1,29 +1,36 @@
-"""Context class to silence transformers and diffusers warnings."""
-
import warnings
-from typing import Any
+from contextlib import ContextDecorator

-from diffusers import logging as diffusers_logging
+from diffusers.utils import logging as diffusers_logging
from transformers import logging as transformers_logging


-class SilenceWarnings(object):
-    """Use in context to temporarily turn off warnings from transformers & diffusers modules.
+# Inherit from ContextDecorator to allow using SilenceWarnings as both a context manager and a decorator.
+class SilenceWarnings(ContextDecorator):
+    """A context manager that disables warnings from transformers & diffusers modules while active.
+
+    As context manager:
+    ```
    with SilenceWarnings():
        # do something
+    ```
+
+    As decorator:
+    ```
+    @SilenceWarnings()
+    def some_function():
+        # do something
+    ```
    """

-    def __init__(self) -> None:
-        self.transformers_verbosity = transformers_logging.get_verbosity()
-        self.diffusers_verbosity = diffusers_logging.get_verbosity()
-
    def __enter__(self) -> None:
+        self._transformers_verbosity = transformers_logging.get_verbosity()
+        self._diffusers_verbosity = diffusers_logging.get_verbosity()
        transformers_logging.set_verbosity_error()
        diffusers_logging.set_verbosity_error()
        warnings.simplefilter("ignore")

-    def __exit__(self, *args: Any) -> None:
-        transformers_logging.set_verbosity(self.transformers_verbosity)
-        diffusers_logging.set_verbosity(self.diffusers_verbosity)
+    def __exit__(self, *args) -> None:
+        transformers_logging.set_verbosity(self._transformers_verbosity)
+        diffusers_logging.set_verbosity(self._diffusers_verbosity)
        warnings.simplefilter("default")
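Because SilenceWarnings now inherits from contextlib.ContextDecorator, the same __enter__/__exit__ pair runs whether it wraps a with-block or decorates a function. A self-contained toy (not InvokeAI code) showing why the decorator form comes for free:

```
from contextlib import ContextDecorator

class Quiet(ContextDecorator):
    def __enter__(self):
        print("muting warnings")
        return self

    def __exit__(self, *exc):
        print("restoring warnings")
        return False  # do not swallow exceptions

@Quiet()
def load_model():
    print("loading")

load_model()  # prints: muting warnings / loading / restoring warnings
```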
@@ -1,17 +1,43 @@
import base64
import io
import os
-import warnings
+import re
+import unicodedata
from pathlib import Path

-from diffusers import logging as diffusers_logging
from PIL import Image
-from transformers import logging as transformers_logging

# actual size of a gig
GIG = 1073741824


+def slugify(value: str, allow_unicode: bool = False) -> str:
+    """
+    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
+    dashes to single dashes. Remove characters that aren't alphanumerics,
+    underscores, or hyphens. Replace slashes with underscores.
+    Convert to lowercase. Also strip leading and
+    trailing whitespace, dashes, and underscores.
+
+    Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
+    """
+    value = str(value)
+    if allow_unicode:
+        value = unicodedata.normalize("NFKC", value)
+    else:
+        value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
+    value = re.sub(r"[/]", "_", value.lower())
+    value = re.sub(r"[^.\w\s-]", "", value.lower())
+    return re.sub(r"[-\s]+", "-", value).strip("-_")
+
+
+def safe_filename(directory: Path, value: str) -> str:
+    """Make a string safe to use as a filename."""
+    escaped_string = slugify(value)
+    max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256
+    return escaped_string[len(escaped_string) - max_name_length :]


def directory_size(directory: Path) -> int:
    """
    Return the aggregate size of all files in a directory (bytes).
@@ -51,21 +77,3 @@ class Chdir(object):

    def __exit__(self, *args):
        os.chdir(self.original)
-
-
-class SilenceWarnings(object):
-    """Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
-
-    def __enter__(self):
-        """Set verbosity to error."""
-        self.transformers_verbosity = transformers_logging.get_verbosity()
-        self.diffusers_verbosity = diffusers_logging.get_verbosity()
-        transformers_logging.set_verbosity_error()
-        diffusers_logging.set_verbosity_error()
-        warnings.simplefilter("ignore")
-
-    def __exit__(self, type, value, traceback):
-        """Restore logger verbosity to state before context was entered."""
-        transformers_logging.set_verbosity(self.transformers_verbosity)
-        diffusers_logging.set_verbosity(self.diffusers_verbosity)
-        warnings.simplefilter("default")
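For illustration, a hedged usage sketch of the helpers added above. The import path assumes they live in invokeai.backend.util.util (the module containing GIG and Chdir); the expected output simply follows the regex steps in slugify:

```
from pathlib import Path

# Assumed import path for the new helpers.
from invokeai.backend.util.util import safe_filename, slugify

print(slugify("My Model/v1.5 (pruned).safetensors"))
# expected: my-model_v1.5-pruned.safetensors

print(safe_filename(Path("."), "My Model/v1.5 (pruned).safetensors"))
# same slug, truncated to the filesystem's maximum filename length if necessary
```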
@@ -1021,7 +1021,8 @@
    "float": "Kommazahlen",
    "enum": "Aufzählung",
    "fullyContainNodes": "Vollständig ausgewählte Nodes auswählen",
-    "editMode": "Im Workflow-Editor bearbeiten"
+    "editMode": "Im Workflow-Editor bearbeiten",
+    "resetToDefaultValue": "Auf Standardwert zurücksetzen"
  },
  "hrf": {
    "enableHrf": "Korrektur für hohe Auflösungen",
@@ -148,6 +148,8 @@
    "viewingDesc": "Review images in a large gallery view",
    "editing": "Editing",
    "editingDesc": "Edit on the Control Layers canvas",
+    "comparing": "Comparing",
+    "comparingDesc": "Comparing two images",
    "enabled": "Enabled",
    "disabled": "Disabled"
  },
@@ -375,7 +377,23 @@
    "bulkDownloadRequestFailed": "Problem Preparing Download",
    "bulkDownloadFailed": "Download Failed",
    "problemDeletingImages": "Problem Deleting Images",
-    "problemDeletingImagesDesc": "One or more images could not be deleted"
+    "problemDeletingImagesDesc": "One or more images could not be deleted",
+    "viewerImage": "Viewer Image",
+    "compareImage": "Compare Image",
+    "openInViewer": "Open in Viewer",
+    "selectForCompare": "Select for Compare",
+    "selectAnImageToCompare": "Select an Image to Compare",
+    "slider": "Slider",
+    "sideBySide": "Side-by-Side",
+    "hover": "Hover",
+    "swapImages": "Swap Images",
+    "compareOptions": "Comparison Options",
+    "stretchToFit": "Stretch to Fit",
+    "exitCompare": "Exit Compare",
+    "compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
+    "compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
+    "compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
+    "compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
  },
  "hotkeys": {
    "searchHotkeys": "Search Hotkeys",
@@ -6,7 +6,7 @@
    "settingsLabel": "Ajustes",
    "img2img": "Imagen a Imagen",
    "unifiedCanvas": "Lienzo Unificado",
-    "nodes": "Editor del flujo de trabajo",
+    "nodes": "Flujos de trabajo",
    "upload": "Subir imagen",
    "load": "Cargar",
    "statusDisconnected": "Desconectado",
@@ -14,7 +14,7 @@
    "discordLabel": "Discord",
    "back": "Atrás",
    "loading": "Cargando",
-    "postprocessing": "Tratamiento posterior",
+    "postprocessing": "Postprocesado",
    "txt2img": "De texto a imagen",
    "accept": "Aceptar",
    "cancel": "Cancelar",
@@ -42,7 +42,42 @@
    "copy": "Copiar",
    "beta": "Beta",
    "on": "En",
-    "aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:"
+    "aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:",
+    "installed": "Instalado",
+    "green": "Verde",
+    "editor": "Editor",
+    "orderBy": "Ordenar por",
+    "file": "Archivo",
+    "goTo": "Ir a",
+    "imageFailedToLoad": "No se puede cargar la imagen",
+    "saveAs": "Guardar Como",
+    "somethingWentWrong": "Algo salió mal",
+    "nextPage": "Página Siguiente",
+    "selected": "Seleccionado",
+    "tab": "Tabulador",
+    "positivePrompt": "Prompt Positivo",
+    "negativePrompt": "Prompt Negativo",
+    "error": "Error",
+    "format": "formato",
+    "unknown": "Desconocido",
+    "input": "Entrada",
+    "nodeEditor": "Editor de nodos",
+    "template": "Plantilla",
+    "prevPage": "Página Anterior",
+    "red": "Rojo",
+    "alpha": "Transparencia",
+    "outputs": "Salidas",
+    "editing": "Editando",
+    "learnMore": "Aprende más",
+    "enabled": "Activado",
+    "disabled": "Desactivado",
+    "folder": "Carpeta",
+    "updated": "Actualizado",
+    "created": "Creado",
+    "save": "Guardar",
+    "unknownError": "Error Desconocido",
+    "blue": "Azul",
+    "viewingDesc": "Revisar imágenes en una vista de galería grande"
  },
  "gallery": {
    "galleryImageSize": "Tamaño de la imagen",
@@ -467,7 +502,8 @@
    "about": "Acerca de",
    "createIssue": "Crear un problema",
    "resetUI": "Interfaz de usuario $t(accessibility.reset)",
-    "mode": "Modo"
+    "mode": "Modo",
+    "submitSupportTicket": "Enviar Ticket de Soporte"
  },
  "nodes": {
    "zoomInNodes": "Acercar",
@@ -543,5 +579,17 @@
    "layers_one": "Capa",
    "layers_many": "Capas",
    "layers_other": "Capas"
+  },
+  "controlnet": {
+    "crop": "Cortar",
+    "delete": "Eliminar",
+    "depthAnythingDescription": "Generación de mapa de profundidad usando la técnica de Depth Anything",
+    "duplicate": "Duplicar",
+    "colorMapDescription": "Genera un mapa de color desde la imagen",
+    "depthMidasDescription": "Crea un mapa de profundidad con Midas",
+    "balanced": "Equilibrado",
+    "beginEndStepPercent": "Inicio / Final Porcentaje de pasos",
+    "detectResolution": "Detectar resolución",
+    "beginEndStepPercentShort": "Inicio / Final %"
  }
}
@@ -45,7 +45,7 @@
    "outputs": "Risultati",
    "data": "Dati",
    "somethingWentWrong": "Qualcosa è andato storto",
-    "copyError": "$t(gallery.copy) Errore",
+    "copyError": "Errore $t(gallery.copy)",
    "input": "Ingresso",
    "notInstalled": "Non $t(common.installed)",
    "unknownError": "Errore sconosciuto",
@@ -85,7 +85,11 @@
    "viewing": "Visualizza",
    "viewingDesc": "Rivedi le immagini in un'ampia vista della galleria",
    "editing": "Modifica",
-    "editingDesc": "Modifica nell'area Livelli di controllo"
+    "editingDesc": "Modifica nell'area Livelli di controllo",
+    "enabled": "Abilitato",
+    "disabled": "Disabilitato",
+    "comparingDesc": "Confronta due immagini",
+    "comparing": "Confronta"
  },
  "gallery": {
    "galleryImageSize": "Dimensione dell'immagine",
@@ -122,14 +126,30 @@
    "bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
    "bulkDownloadRequestFailed": "Problema durante la preparazione del download",
    "bulkDownloadFailed": "Scaricamento fallito",
-    "alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine"
+    "alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine",
+    "openInViewer": "Apri nel visualizzatore",
+    "selectForCompare": "Seleziona per il confronto",
+    "selectAnImageToCompare": "Seleziona un'immagine da confrontare",
+    "slider": "Cursore",
+    "sideBySide": "Fianco a Fianco",
+    "compareImage": "Immagine di confronto",
+    "viewerImage": "Immagine visualizzata",
+    "hover": "Al passaggio del mouse",
+    "swapImages": "Scambia le immagini",
+    "compareOptions": "Opzioni di confronto",
+    "stretchToFit": "Scala per adattare",
+    "exitCompare": "Esci dal confronto",
+    "compareHelp1": "Tieni premuto <Kbd>Alt</Kbd> mentre fai clic su un'immagine della galleria o usi i tasti freccia per cambiare l'immagine di confronto.",
+    "compareHelp2": "Premi <Kbd>M</Kbd> per scorrere le modalità di confronto.",
+    "compareHelp3": "Premi <Kbd>C</Kbd> per scambiare le immagini confrontate.",
+    "compareHelp4": "Premi <Kbd>Z</Kbd> o <Kbd>Esc</Kbd> per uscire."
  },
  "hotkeys": {
    "keyboardShortcuts": "Tasti di scelta rapida",
    "appHotkeys": "Applicazione",
    "generalHotkeys": "Generale",
    "galleryHotkeys": "Galleria",
-    "unifiedCanvasHotkeys": "Tela Unificata",
+    "unifiedCanvasHotkeys": "Tela",
    "invoke": {
      "title": "Invoke",
      "desc": "Genera un'immagine"
@@ -147,8 +167,8 @@
      "desc": "Apre e chiude il pannello delle opzioni"
    },
    "pinOptions": {
-      "title": "Appunta le opzioni",
-      "desc": "Blocca il pannello delle opzioni"
+      "title": "Fissa le opzioni",
+      "desc": "Fissa il pannello delle opzioni"
    },
    "toggleGallery": {
      "title": "Attiva/disattiva galleria",
@@ -332,14 +352,14 @@
      "title": "Annulla e cancella"
    },
    "resetOptionsAndGallery": {
-      "title": "Ripristina Opzioni e Galleria",
-      "desc": "Reimposta le opzioni e i pannelli della galleria"
+      "title": "Ripristina le opzioni e la galleria",
+      "desc": "Reimposta i pannelli delle opzioni e della galleria"
    },
    "searchHotkeys": "Cerca tasti di scelta rapida",
    "noHotkeysFound": "Nessun tasto di scelta rapida trovato",
    "toggleOptionsAndGallery": {
      "desc": "Apre e chiude le opzioni e i pannelli della galleria",
-      "title": "Attiva/disattiva le Opzioni e la Galleria"
+      "title": "Attiva/disattiva le opzioni e la galleria"
    },
    "clearSearch": "Cancella ricerca",
    "remixImage": {
@@ -348,7 +368,7 @@
    },
    "toggleViewer": {
      "title": "Attiva/disattiva il visualizzatore di immagini",
-      "desc": "Passa dal Visualizzatore immagini all'area di lavoro per la scheda corrente."
+      "desc": "Passa dal visualizzatore immagini all'area di lavoro per la scheda corrente."
    }
  },
  "modelManager": {
@@ -378,7 +398,7 @@
    "convertToDiffusers": "Converti in Diffusori",
    "convertToDiffusersHelpText2": "Questo processo sostituirà la voce in Gestione Modelli con la versione Diffusori dello stesso modello.",
    "convertToDiffusersHelpText4": "Questo è un processo una tantum. Potrebbero essere necessari circa 30-60 secondi a seconda delle specifiche del tuo computer.",
-    "convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB di dimensioni.",
+    "convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB in dimensione.",
    "convertToDiffusersHelpText6": "Vuoi convertire questo modello?",
    "modelConverted": "Modello convertito",
    "alpha": "Alpha",
@@ -528,7 +548,7 @@
    "layer": {
      "initialImageNoImageSelected": "Nessuna immagine iniziale selezionata",
      "t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}",
-      "controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato",
+      "controlAdapterNoModelSelected": "Nessun modello di adattatore di controllo selezionato",
      "controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile",
      "controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata",
      "controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata",
@@ -606,25 +626,25 @@
    "canvasMerged": "Tela unita",
    "sentToImageToImage": "Inviato a Generazione da immagine",
    "sentToUnifiedCanvas": "Inviato alla Tela",
-    "parametersNotSet": "Parametri non impostati",
+    "parametersNotSet": "Parametri non richiamati",
    "metadataLoadFailed": "Impossibile caricare i metadati",
    "serverError": "Errore del Server",
-    "connected": "Connesso al Server",
+    "connected": "Connesso al server",
    "canceled": "Elaborazione annullata",
    "uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
-    "parameterSet": "{{parameter}} impostato",
-    "parameterNotSet": "{{parameter}} non impostato",
+    "parameterSet": "Parametro richiamato",
+    "parameterNotSet": "Parametro non richiamato",
    "problemCopyingImage": "Impossibile copiare l'immagine",
-    "baseModelChangedCleared_one": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modello incompatibile",
-    "baseModelChangedCleared_many": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
-    "baseModelChangedCleared_other": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
+    "baseModelChangedCleared_one": "Cancellato o disabilitato {{count}} sottomodello incompatibile",
+    "baseModelChangedCleared_many": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
+    "baseModelChangedCleared_other": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
    "imageSavingFailed": "Salvataggio dell'immagine non riuscito",
    "canvasSentControlnetAssets": "Tela inviata a ControlNet & Risorse",
    "problemCopyingCanvasDesc": "Impossibile copiare la tela",
    "loadedWithWarnings": "Flusso di lavoro caricato con avvisi",
    "canvasCopiedClipboard": "Tela copiata negli appunti",
    "maskSavedAssets": "Maschera salvata nelle risorse",
-    "problemDownloadingCanvas": "Problema durante il download della tela",
+    "problemDownloadingCanvas": "Problema durante lo scarico della tela",
    "problemMergingCanvas": "Problema nell'unione delle tele",
    "imageUploaded": "Immagine caricata",
    "addedToBoard": "Aggiunto alla bacheca",
@@ -658,7 +678,17 @@
    "problemDownloadingImage": "Impossibile scaricare l'immagine",
    "prunedQueue": "Coda ripulita",
    "modelImportCanceled": "Importazione del modello annullata",
-    "parameters": "Parametri"
+    "parameters": "Parametri",
+    "parameterSetDesc": "{{parameter}} richiamato",
+    "parameterNotSetDesc": "Impossibile richiamare {{parameter}}",
+    "parameterNotSetDescWithMessage": "Impossibile richiamare {{parameter}}: {{message}}",
+    "parametersSet": "Parametri richiamati",
+    "errorCopied": "Errore copiato",
+    "outOfMemoryError": "Errore di memoria esaurita",
+    "baseModelChanged": "Modello base modificato",
+    "sessionRef": "Sessione: {{sessionId}}",
+    "somethingWentWrong": "Qualcosa è andato storto",
+    "outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova."
  },
  "tooltip": {
    "feature": {
@@ -674,7 +704,7 @@
    "layer": "Livello",
    "base": "Base",
    "mask": "Maschera",
-    "maskingOptions": "Opzioni di mascheramento",
+    "maskingOptions": "Opzioni maschera",
    "enableMask": "Abilita maschera",
    "preserveMaskedArea": "Mantieni area mascherata",
    "clearMask": "Cancella maschera (Shift+C)",
@@ -745,7 +775,8 @@
    "mode": "Modalità",
    "resetUI": "$t(accessibility.reset) l'Interfaccia Utente",
    "createIssue": "Segnala un problema",
-    "about": "Informazioni"
+    "about": "Informazioni",
+    "submitSupportTicket": "Invia ticket di supporto"
  },
  "nodes": {
    "zoomOutNodes": "Rimpicciolire",
@@ -790,7 +821,7 @@
    "workflowNotes": "Note",
    "versionUnknown": " Versione sconosciuta",
    "unableToValidateWorkflow": "Impossibile convalidare il flusso di lavoro",
-    "updateApp": "Aggiorna App",
+    "updateApp": "Aggiorna Applicazione",
    "unableToLoadWorkflow": "Impossibile caricare il flusso di lavoro",
    "updateNode": "Aggiorna nodo",
    "version": "Versione",
@@ -882,11 +913,14 @@
    "missingNode": "Nodo di invocazione mancante",
    "missingInvocationTemplate": "Modello di invocazione mancante",
    "missingFieldTemplate": "Modello di campo mancante",
-    "singleFieldType": "{{name}} (Singola)"
+    "singleFieldType": "{{name}} (Singola)",
+    "imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
+    "boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
+    "modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
  },
  "boards": {
    "autoAddBoard": "Aggiungi automaticamente bacheca",
-    "menuItemAutoAdd": "Aggiungi automaticamente a questa Bacheca",
+    "menuItemAutoAdd": "Aggiungi automaticamente a questa bacheca",
    "cancel": "Annulla",
    "addBoard": "Aggiungi Bacheca",
    "bottomMessage": "L'eliminazione di questa bacheca e delle sue immagini ripristinerà tutte le funzionalità che le stanno attualmente utilizzando.",
@@ -898,7 +932,7 @@
    "myBoard": "Bacheca",
    "searchBoard": "Cerca bacheche ...",
    "noMatching": "Nessuna bacheca corrispondente",
-    "selectBoard": "Seleziona una Bacheca",
+    "selectBoard": "Seleziona una bacheca",
    "uncategorized": "Non categorizzato",
    "downloadBoard": "Scarica la bacheca",
    "deleteBoardOnly": "solo la Bacheca",
@@ -919,7 +953,7 @@
    "control": "Controllo",
    "crop": "Ritaglia",
    "depthMidas": "Profondità (Midas)",
-    "detectResolution": "Rileva risoluzione",
+    "detectResolution": "Rileva la risoluzione",
    "controlMode": "Modalità di controllo",
    "cannyDescription": "Canny rilevamento bordi",
    "depthZoe": "Profondità (Zoe)",
@@ -930,7 +964,7 @@
    "showAdvanced": "Mostra opzioni Avanzate",
    "bgth": "Soglia rimozione sfondo",
    "importImageFromCanvas": "Importa immagine dalla Tela",
-    "lineartDescription": "Converte l'immagine in lineart",
+    "lineartDescription": "Converte l'immagine in linea",
    "importMaskFromCanvas": "Importa maschera dalla Tela",
    "hideAdvanced": "Nascondi opzioni avanzate",
    "resetControlImage": "Reimposta immagine di controllo",
@@ -946,7 +980,7 @@
    "pidiDescription": "Elaborazione immagini PIDI",
    "fill": "Riempie",
    "colorMapDescription": "Genera una mappa dei colori dall'immagine",
-    "lineartAnimeDescription": "Elaborazione lineart in stile anime",
+    "lineartAnimeDescription": "Elaborazione linea in stile anime",
    "imageResolution": "Risoluzione dell'immagine",
    "colorMap": "Colore",
    "lowThreshold": "Soglia inferiore",
@@ -87,7 +87,11 @@
    "viewing": "Просмотр",
    "editing": "Редактирование",
    "viewingDesc": "Просмотр изображений в режиме большой галереи",
-    "editingDesc": "Редактировать на холсте слоёв управления"
+    "editingDesc": "Редактировать на холсте слоёв управления",
+    "enabled": "Включено",
+    "disabled": "Отключено",
+    "comparingDesc": "Сравнение двух изображений",
+    "comparing": "Сравнение"
  },
  "gallery": {
    "galleryImageSize": "Размер изображений",
@@ -124,7 +128,23 @@
    "bulkDownloadRequested": "Подготовка к скачиванию",
    "bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
    "bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания",
-    "alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения"
+    "alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения",
+    "openInViewer": "Открыть в просмотрщике",
+    "selectForCompare": "Выбрать для сравнения",
+    "hover": "Наведение",
+    "swapImages": "Поменять местами",
+    "stretchToFit": "Растягивание до нужного размера",
+    "exitCompare": "Выйти из сравнения",
+    "compareHelp4": "Нажмите <Kbd>Z</Kbd> или <Kbd>Esc</Kbd> для выхода.",
+    "compareImage": "Сравнить изображение",
+    "viewerImage": "Изображение просмотрщика",
+    "selectAnImageToCompare": "Выберите изображение для сравнения",
+    "slider": "Слайдер",
+    "sideBySide": "Бок о бок",
+    "compareOptions": "Варианты сравнения",
+    "compareHelp1": "Удерживайте <Kbd>Alt</Kbd> при нажатии на изображение в галерее или при помощи клавиш со стрелками, чтобы изменить сравниваемое изображение.",
+    "compareHelp2": "Нажмите <Kbd>M</Kbd>, чтобы переключиться между режимами сравнения.",
+    "compareHelp3": "Нажмите <Kbd>C</Kbd>, чтобы поменять местами сравниваемые изображения."
  },
  "hotkeys": {
    "keyboardShortcuts": "Горячие клавиши",
@@ -528,7 +548,20 @@
    "missingFieldTemplate": "Отсутствует шаблон поля",
    "addingImagesTo": "Добавление изображений в",
    "invoke": "Создать",
-    "imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается"
+    "imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается",
+    "layer": {
+      "controlAdapterImageNotProcessed": "Изображение адаптера контроля не обработано",
+      "ipAdapterNoModelSelected": "IP адаптер не выбран",
+      "controlAdapterNoModelSelected": "не выбрана модель адаптера контроля",
+      "controlAdapterIncompatibleBaseModel": "несовместимая базовая модель адаптера контроля",
+      "controlAdapterNoImageSelected": "не выбрано изображение контрольного адаптера",
+      "initialImageNoImageSelected": "начальное изображение не выбрано",
+      "rgNoRegion": "регион не выбран",
+      "rgNoPromptsOrIPAdapters": "нет текстовых запросов или IP-адаптеров",
+      "ipAdapterIncompatibleBaseModel": "несовместимая базовая модель IP-адаптера",
+      "t2iAdapterIncompatibleDimensions": "Адаптер T2I требует, чтобы размеры изображения были кратны {{multiple}}",
+      "ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано"
+    }
  },
  "isAllowedToUpscale": {
    "useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2",
@@ -606,12 +639,12 @@
    "connected": "Подключено к серверу",
    "canceled": "Обработка отменена",
    "uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
-    "parameterNotSet": "Параметр {{parameter}} не задан",
-    "parameterSet": "Параметр {{parameter}} задан",
+    "parameterNotSet": "Параметр не задан",
+    "parameterSet": "Параметр задан",
    "problemCopyingImage": "Не удается скопировать изображение",
-    "baseModelChangedCleared_one": "Базовая модель изменила, очистила или отключила {{count}} несовместимую подмодель",
-    "baseModelChangedCleared_few": "Базовая модель изменила, очистила или отключила {{count}} несовместимые подмодели",
-    "baseModelChangedCleared_many": "Базовая модель изменила, очистила или отключила {{count}} несовместимых подмоделей",
+    "baseModelChangedCleared_one": "Очищена или отключена {{count}} несовместимая подмодель",
+    "baseModelChangedCleared_few": "Очищены или отключены {{count}} несовместимые подмодели",
+    "baseModelChangedCleared_many": "Очищены или отключены {{count}} несовместимых подмоделей",
    "imageSavingFailed": "Не удалось сохранить изображение",
    "canvasSentControlnetAssets": "Холст отправлен в ControlNet и ресурсы",
    "problemCopyingCanvasDesc": "Невозможно экспортировать базовый слой",
@@ -652,7 +685,17 @@
    "resetInitialImage": "Сбросить начальное изображение",
    "prunedQueue": "Урезанная очередь",
    "modelImportCanceled": "Импорт модели отменен",
-    "parameters": "Параметры"
+    "parameters": "Параметры",
+    "parameterSetDesc": "Задан {{parameter}}",
+    "parameterNotSetDesc": "Невозможно задать {{parameter}}",
+    "baseModelChanged": "Базовая модель сменена",
+    "parameterNotSetDescWithMessage": "Не удалось задать {{parameter}}: {{message}}",
+    "parametersSet": "Параметры заданы",
+    "errorCopied": "Ошибка скопирована",
+    "sessionRef": "Сессия: {{sessionId}}",
+    "outOfMemoryError": "Ошибка нехватки памяти",
+    "outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
+    "somethingWentWrong": "Что-то пошло не так"
  },
  "tooltip": {
    "feature": {
@@ -739,7 +782,8 @@
    "loadMore": "Загрузить больше",
    "resetUI": "$t(accessibility.reset) интерфейс",
    "createIssue": "Сообщить о проблеме",
-    "about": "Об этом"
+    "about": "Об этом",
+    "submitSupportTicket": "Отправить тикет в службу поддержки"
  },
  "nodes": {
    "zoomInNodes": "Увеличьте масштаб",
@@ -832,7 +876,7 @@
    "workflowName": "Название",
    "collection": "Коллекция",
    "unknownErrorValidatingWorkflow": "Неизвестная ошибка при проверке рабочего процесса",
-    "collectionFieldType": "Коллекция {{name}}",
+    "collectionFieldType": "{{name}} (Коллекция)",
    "workflowNotes": "Примечания",
    "string": "Строка",
    "unknownNodeType": "Неизвестный тип узла",
@@ -848,7 +892,7 @@
    "targetNodeDoesNotExist": "Недопустимое ребро: целевой/входной узел {{node}} не существует",
    "mismatchedVersion": "Недопустимый узел: узел {{node}} типа {{type}} имеет несоответствующую версию (попробовать обновить?)",
    "unknownFieldType": "$t(nodes.unknownField) тип: {{type}}",
-    "collectionOrScalarFieldType": "Коллекция | Скаляр {{name}}",
+    "collectionOrScalarFieldType": "{{name}} (Один или коллекция)",
    "betaDesc": "Этот вызов находится в бета-версии. Пока он не станет стабильным, в нем могут происходить изменения при обновлении приложений. Мы планируем поддерживать этот вызов в течение длительного времени.",
    "nodeVersion": "Версия узла",
    "loadingNodes": "Загрузка узлов...",
@@ -870,7 +914,16 @@
    "noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений.",
    "graph": "График",
    "showEdgeLabels": "Показать метки на ребрах",
-    "showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы"
+    "showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы",
+    "cannotMixAndMatchCollectionItemTypes": "Невозможно смешивать и сопоставлять типы элементов коллекции",
+    "missingNode": "Отсутствует узел вызова",
+    "missingInvocationTemplate": "Отсутствует шаблон вызова",
+    "missingFieldTemplate": "Отсутствующий шаблон поля",
+    "singleFieldType": "{{name}} (Один)",
+    "noGraph": "Нет графика",
+    "imageAccessError": "Невозможно найти изображение {{image_name}}, сбрасываем на значение по умолчанию",
+    "boardAccessError": "Невозможно найти доску {{board_id}}, сбрасываем на значение по умолчанию",
+    "modelAccessError": "Невозможно найти модель {{key}}, сброс на модель по умолчанию"
  },
  "controlnet": {
|
||||||
"amult": "a_mult",
|
"amult": "a_mult",
|
||||||
@@ -1441,7 +1494,16 @@
|
|||||||
"clearQueueAlertDialog2": "Вы уверены, что хотите очистить очередь?",
|
"clearQueueAlertDialog2": "Вы уверены, что хотите очистить очередь?",
|
||||||
"item": "Элемент",
|
"item": "Элемент",
|
||||||
"graphFailedToQueue": "Не удалось поставить график в очередь",
|
"graphFailedToQueue": "Не удалось поставить график в очередь",
|
||||||
"openQueue": "Открыть очередь"
|
"openQueue": "Открыть очередь",
|
||||||
|
"prompts_one": "Запрос",
|
||||||
|
"prompts_few": "Запроса",
|
||||||
|
"prompts_many": "Запросов",
|
||||||
|
"iterations_one": "Итерация",
|
||||||
|
"iterations_few": "Итерации",
|
||||||
|
"iterations_many": "Итераций",
|
||||||
|
"generations_one": "Генерация",
|
||||||
|
"generations_few": "Генерации",
|
||||||
|
"generations_many": "Генераций"
|
||||||
},
|
},
|
||||||
"sdxl": {
|
"sdxl": {
|
||||||
"refinerStart": "Запуск доработчика",
|
"refinerStart": "Запуск доработчика",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
 {
 "common": {
-"nodes": "節點",
+"nodes": "工作流程",
 "img2img": "圖片轉圖片",
 "statusDisconnected": "已中斷連線",
 "back": "返回",
@@ -11,17 +11,239 @@
 "reportBugLabel": "回報錯誤",
 "githubLabel": "GitHub",
 "hotkeysLabel": "快捷鍵",
-"languagePickerLabel": "切換語言",
+"languagePickerLabel": "語言",
 "unifiedCanvas": "統一畫布",
 "cancel": "取消",
-"txt2img": "文字轉圖片"
+"txt2img": "文字轉圖片",
+"controlNet": "ControlNet",
+"advanced": "進階",
+"folder": "資料夾",
+"installed": "已安裝",
+"accept": "接受",
+"goTo": "前往",
+"input": "輸入",
+"random": "隨機",
+"selected": "已選擇",
+"communityLabel": "社群",
+"loading": "載入中",
+"delete": "刪除",
+"copy": "複製",
+"error": "錯誤",
+"file": "檔案",
+"format": "格式",
+"imageFailedToLoad": "無法載入圖片"
 },
 "accessibility": {
 "invokeProgressBar": "Invoke 進度條",
 "uploadImage": "上傳圖片",
-"reset": "重設",
+"reset": "重置",
 "nextImage": "下一張圖片",
 "previousImage": "上一張圖片",
-"menu": "選單"
+"menu": "選單",
+"loadMore": "載入更多",
+"about": "關於",
+"createIssue": "建立問題",
+"resetUI": "$t(accessibility.reset) 介面",
+"submitSupportTicket": "提交支援工單",
+"mode": "模式"
+},
+"boards": {
+"loading": "載入中…",
+"movingImagesToBoard_other": "正在移動 {{count}} 張圖片至板上:",
+"move": "移動",
+"uncategorized": "未分類",
+"cancel": "取消"
+},
+"metadata": {
+"workflow": "工作流程",
+"steps": "步數",
+"model": "模型",
+"seed": "種子",
+"vae": "VAE",
+"seamless": "無縫",
+"metadata": "元數據",
+"width": "寬度",
+"height": "高度"
+},
+"accordions": {
+"control": {
+"title": "控制"
+},
+"compositing": {
+"title": "合成"
+},
+"advanced": {
+"title": "進階",
+"options": "$t(accordions.advanced.title) 選項"
+}
+},
+"hotkeys": {
+"nodesHotkeys": "節點",
+"cancel": {
+"title": "取消"
+},
+"generalHotkeys": "一般",
+"keyboardShortcuts": "快捷鍵",
+"appHotkeys": "應用程式"
+},
+"modelManager": {
+"advanced": "進階",
+"allModels": "全部模型",
+"variant": "變體",
+"config": "配置",
+"model": "模型",
+"selected": "已選擇",
+"huggingFace": "HuggingFace",
+"install": "安裝",
+"metadata": "元數據",
+"delete": "刪除",
+"description": "描述",
+"cancel": "取消",
+"convert": "轉換",
+"manual": "手動",
+"none": "無",
+"name": "名稱",
+"load": "載入",
+"height": "高度",
+"width": "寬度",
+"search": "搜尋",
+"vae": "VAE",
+"settings": "設定"
+},
+"controlnet": {
+"mlsd": "M-LSD",
+"canny": "Canny",
+"duplicate": "重複",
+"none": "無",
+"pidi": "PIDI",
+"h": "H",
+"balanced": "平衡",
+"crop": "裁切",
+"processor": "處理器",
+"control": "控制",
+"f": "F",
+"lineart": "線條藝術",
+"w": "W",
+"hed": "HED",
+"delete": "刪除"
+},
+"queue": {
+"queue": "佇列",
+"canceled": "已取消",
+"failed": "已失敗",
+"completed": "已完成",
+"cancel": "取消",
+"session": "工作階段",
+"batch": "批量",
+"item": "項目",
+"completedIn": "完成於",
+"notReady": "無法排隊"
+},
+"parameters": {
+"cancel": {
+"cancel": "取消"
+},
+"height": "高度",
+"type": "類型",
+"symmetry": "對稱性",
+"images": "圖片",
+"width": "寬度",
+"coherenceMode": "模式",
+"seed": "種子",
+"general": "一般",
+"strength": "強度",
+"steps": "步數",
+"info": "資訊"
+},
+"settings": {
+"beta": "Beta",
+"developer": "開發者",
+"general": "一般",
+"models": "模型"
+},
+"popovers": {
+"paramModel": {
+"heading": "模型"
+},
+"compositingCoherenceMode": {
+"heading": "模式"
+},
+"paramSteps": {
+"heading": "步數"
+},
+"controlNetProcessor": {
+"heading": "處理器"
+},
+"paramVAE": {
+"heading": "VAE"
+},
+"paramHeight": {
+"heading": "高度"
+},
+"paramSeed": {
+"heading": "種子"
+},
+"paramWidth": {
+"heading": "寬度"
+},
+"refinerSteps": {
+"heading": "步數"
+}
+},
+"unifiedCanvas": {
+"undo": "復原",
+"mask": "遮罩",
+"eraser": "橡皮擦",
+"antialiasing": "抗鋸齒",
+"redo": "重做",
+"layer": "圖層",
+"accept": "接受",
+"brush": "刷子",
+"move": "移動",
+"brushSize": "大小"
+},
+"nodes": {
+"workflowName": "名稱",
+"notes": "註釋",
+"workflowVersion": "版本",
+"workflowNotes": "註釋",
+"executionStateError": "錯誤",
+"unableToUpdateNodes_other": "無法更新 {{count}} 個節點",
+"integer": "整數",
+"workflow": "工作流程",
+"enum": "枚舉",
+"edit": "編輯",
+"string": "字串",
+"workflowTags": "標籤",
+"node": "節點",
+"boolean": "布林值",
+"workflowAuthor": "作者",
+"version": "版本",
+"executionStateCompleted": "已完成",
+"edge": "邊緣",
+"versionUnknown": " 版本未知"
+},
+"sdxl": {
+"steps": "步數",
+"loading": "載入中…",
+"refiner": "精煉器"
+},
+"gallery": {
+"copy": "複製",
+"download": "下載",
+"loading": "載入中"
+},
+"ui": {
+"tabs": {
+"models": "模型",
+"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
+"queue": "佇列"
+}
+},
+"models": {
+"loading": "載入中"
+},
+"workflows": {
+"name": "名稱"
 }
 }
@@ -19,6 +19,13 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
     return extendTheme({
       ..._theme,
       direction,
+      shadows: {
+        ..._theme.shadows,
+        selectedForCompare:
+          '0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)',
+        hoverSelectedForCompare:
+          '0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)',
+      },
     });
   }, [direction]);
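For orientation, the two new `shadows` entries are ordinary theme tokens; Chakra-style components resolve string `boxShadow` values against `theme.shadows`, so they can be referenced by name. A minimal, hypothetical consumer (not part of this diff; the import assumes plain Chakra UI rather than the project's own UI wrapper):

```tsx
import { Box } from '@chakra-ui/react';

// Hypothetical example: the green double-ring highlight defined above is applied
// by naming the shadow token; Chakra looks it up in theme.shadows at render time.
const CompareHighlightExample = () => (
  <Box boxShadow="selectedForCompare" borderRadius="base" p={4}>
    Selected for compare
  </Box>
);
```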
@@ -13,7 +13,6 @@ import {
   isControlAdapterLayer,
 } from 'features/controlLayers/store/controlLayersSlice';
 import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
-import { isImageOutput } from 'features/nodes/types/common';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { isEqual } from 'lodash-es';
@@ -23,7 +22,13 @@ import type { BatchConfig } from 'services/api/types';
 import { socketInvocationComplete } from 'services/events/actions';
 import { assert } from 'tsafe';

-const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged, caLayerRecalled);
+const matcher = isAnyOf(
+  caLayerImageChanged,
+  caLayerProcessedImageChanged,
+  caLayerProcessorConfigChanged,
+  caLayerModelChanged,
+  caLayerRecalled
+);

 const DEBOUNCE_MS = 300;
 const log = logger('session');
@@ -74,9 +79,10 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
       const originalConfig = originalLayer?.controlAdapter.processorConfig;

       const image = layer.controlAdapter.image;
+      const processedImage = layer.controlAdapter.processedImage;
       const config = layer.controlAdapter.processorConfig;

-      if (isEqual(config, originalConfig) && isEqual(image, originalImage)) {
+      if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
         // Neither config nor image have changed, we can bail
         return;
       }
@@ -139,7 +145,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni

       // We still have to check the output type
       assert(
-        isImageOutput(invocationCompleteAction.payload.data.result),
+        invocationCompleteAction.payload.data.result.type === 'image_output',
         `Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
       );
       const { image_name } = invocationCompleteAction.payload.data.result.image;
@@ -9,7 +9,6 @@ import {
   selectControlAdapterById,
 } from 'features/controlAdapters/store/controlAdaptersSlice';
 import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
-import { isImageOutput } from 'features/nodes/types/common';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { imagesApi } from 'services/api/endpoints/images';
@@ -74,7 +73,7 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
       );

       // We still have to check the output type
-      if (isImageOutput(invocationCompleteAction.payload.data.result)) {
+      if (invocationCompleteAction.payload.data.result.type === 'image_output') {
         const { image_name } = invocationCompleteAction.payload.data.result.image;

         // Wait for the ImageDTO to be received
@@ -1,7 +1,7 @@
 import { createAction } from '@reduxjs/toolkit';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
 import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
-import { selectionChanged } from 'features/gallery/store/gallerySlice';
+import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
 import { imagesApi } from 'services/api/endpoints/images';
 import type { ImageDTO } from 'services/api/types';
 import { imagesSelectors } from 'services/api/util';
@@ -11,6 +11,7 @@ export const galleryImageClicked = createAction<{
   shiftKey: boolean;
   ctrlKey: boolean;
   metaKey: boolean;
+  altKey: boolean;
 }>('gallery/imageClicked');

 /**
@@ -28,7 +29,7 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
   startAppListening({
     actionCreator: galleryImageClicked,
     effect: async (action, { dispatch, getState }) => {
-      const { imageDTO, shiftKey, ctrlKey, metaKey } = action.payload;
+      const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
       const state = getState();
       const queryArgs = selectListImagesQueryArgs(state);
       const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
@@ -41,7 +42,13 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
       const imageDTOs = imagesSelectors.selectAll(listImagesData);
       const selection = state.gallery.selection;

-      if (shiftKey) {
+      if (altKey) {
+        if (state.gallery.imageToCompare?.image_name === imageDTO.image_name) {
+          dispatch(imageToCompareChanged(null));
+        } else {
+          dispatch(imageToCompareChanged(imageDTO));
+        }
+      } else if (shiftKey) {
         const rangeEndImageName = imageDTO.image_name;
         const lastSelectedImage = selection[selection.length - 1]?.image_name;
         const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);
@@ -14,7 +14,8 @@ import {
   rgLayerIPAdapterImageChanged,
 } from 'features/controlLayers/store/controlLayersSlice';
 import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
-import { imageSelected } from 'features/gallery/store/gallerySlice';
+import { isValidDrop } from 'features/dnd/util/isValidDrop';
+import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
 import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
 import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
 import { imagesApi } from 'services/api/endpoints/images';
@@ -30,6 +31,9 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
     effect: async (action, { dispatch, getState }) => {
       const log = logger('dnd');
       const { activeData, overData } = action.payload;
+      if (!isValidDrop(overData, activeData)) {
+        return;
+      }

       if (activeData.payloadType === 'IMAGE_DTO') {
         log.debug({ activeData, overData }, 'Image dropped');
@@ -50,6 +54,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
         activeData.payload.imageDTO
       ) {
         dispatch(imageSelected(activeData.payload.imageDTO));
+        dispatch(isImageViewerOpenChanged(true));
         return;
       }

@@ -182,24 +187,18 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
       }

       /**
-       * TODO
-       * Image selection dropped on node image collection field
+       * Image selected for compare
        */
-      // if (
-      //   overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
-      //   activeData.payloadType === 'IMAGE_DTO' &&
-      //   activeData.payload.imageDTO
-      // ) {
-      //   const { fieldName, nodeId } = overData.context;
-      //   dispatch(
-      //     fieldValueChanged({
-      //       nodeId,
-      //       fieldName,
-      //       value: [activeData.payload.imageDTO],
-      //     })
-      //   );
-      //   return;
-      // }
+      if (
+        overData.actionType === 'SELECT_FOR_COMPARE' &&
+        activeData.payloadType === 'IMAGE_DTO' &&
+        activeData.payload.imageDTO
+      ) {
+        const { imageDTO } = activeData.payload;
+        dispatch(imageToCompareChanged(imageDTO));
+        dispatch(isImageViewerOpenChanged(true));
+        return;
+      }

       /**
        * Image dropped on user board
@@ -11,7 +11,6 @@ import {
 } from 'features/gallery/store/gallerySlice';
 import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
 import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
-import { isImageOutput } from 'features/nodes/types/common';
 import { zNodeStatus } from 'features/nodes/types/invocation';
 import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
 import { boardsApi } from 'services/api/endpoints/boards';
@@ -33,7 +32,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi

       const { result, invocation_source_id } = data;
       // This complete event has an associated image output
-      if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) {
+      if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
         const { image_name } = data.result.image;
         const { canvas, gallery } = getState();

@@ -5,43 +5,122 @@
   socketModelInstallCancelled,
   socketModelInstallComplete,
   socketModelInstallDownloadProgress,
+  socketModelInstallDownloadsComplete,
+  socketModelInstallDownloadStarted,
   socketModelInstallError,
+  socketModelInstallStarted,
 } from 'services/events/actions';

+/**
+ * A model install has two main stages - downloading and installing. All these events are namespaced under `model_install_`
+ * which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
+ * downloaded and is being "physically" installed.
+ *
+ * Note: the download events are only fired for remote model installs, not local.
+ *
+ * Here's the expected flow:
+ * - API receives install request, model manager preps the install
+ * - `model_install_download_started` fired when the download starts
+ * - `model_install_download_progress` fired continually until the download is complete
+ * - `model_install_download_complete` fired when the download is complete
+ * - `model_install_started` fired when the "physical" installation starts
+ * - `model_install_complete` fired when the installation is complete
+ * - `model_install_cancelled` fired if the installation is cancelled
+ * - `model_install_error` fired if the installation has an error
+ */
+
+const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
+
 export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
   startAppListening({
-    actionCreator: socketModelInstallDownloadProgress,
-    effect: async (action, { dispatch }) => {
-      const { bytes, total_bytes, id } = action.payload.data;
+    actionCreator: socketModelInstallDownloadStarted,
+    effect: async (action, { dispatch, getState }) => {
+      const { id } = action.payload.data;
+      const { data } = selectModelInstalls(getState());

-      dispatch(
-        modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
-          const modelImport = draft.find((m) => m.id === id);
-          if (modelImport) {
-            modelImport.bytes = bytes;
-            modelImport.total_bytes = total_bytes;
-            modelImport.status = 'downloading';
-          }
-          return draft;
-        })
-      );
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
+        dispatch(
+          modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+            const modelImport = draft.find((m) => m.id === id);
+            if (modelImport) {
+              modelImport.status = 'downloading';
+            }
+            return draft;
+          })
+        );
+      }
+    },
+  });
+
+  startAppListening({
+    actionCreator: socketModelInstallStarted,
+    effect: async (action, { dispatch, getState }) => {
+      const { id } = action.payload.data;
+      const { data } = selectModelInstalls(getState());
+
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
+        dispatch(
+          modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+            const modelImport = draft.find((m) => m.id === id);
+            if (modelImport) {
+              modelImport.status = 'running';
+            }
+            return draft;
+          })
+        );
+      }
+    },
+  });
+
+  startAppListening({
+    actionCreator: socketModelInstallDownloadProgress,
+    effect: async (action, { dispatch, getState }) => {
+      const { bytes, total_bytes, id } = action.payload.data;
+      const { data } = selectModelInstalls(getState());
+
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
+        dispatch(
+          modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+            const modelImport = draft.find((m) => m.id === id);
+            if (modelImport) {
+              modelImport.bytes = bytes;
+              modelImport.total_bytes = total_bytes;
+              modelImport.status = 'downloading';
+            }
+            return draft;
+          })
+        );
+      }
     },
   });

   startAppListening({
     actionCreator: socketModelInstallComplete,
-    effect: (action, { dispatch }) => {
+    effect: (action, { dispatch, getState }) => {
       const { id } = action.payload.data;

-      dispatch(
-        modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
-          const modelImport = draft.find((m) => m.id === id);
-          if (modelImport) {
-            modelImport.status = 'completed';
-          }
-          return draft;
-        })
-      );
+      const { data } = selectModelInstalls(getState());
+
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
+        dispatch(
+          modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+            const modelImport = draft.find((m) => m.id === id);
+            if (modelImport) {
+              modelImport.status = 'completed';
+            }
+            return draft;
+          })
+        );
+      }

       dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
       dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
     },
@@ -49,37 +128,69 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
   });

   startAppListening({
     actionCreator: socketModelInstallError,
-    effect: (action, { dispatch }) => {
+    effect: (action, { dispatch, getState }) => {
       const { id, error, error_type } = action.payload.data;
+      const { data } = selectModelInstalls(getState());

-      dispatch(
-        modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
-          const modelImport = draft.find((m) => m.id === id);
-          if (modelImport) {
-            modelImport.status = 'error';
-            modelImport.error_reason = error_type;
-            modelImport.error = error;
-          }
-          return draft;
-        })
-      );
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
+        dispatch(
+          modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+            const modelImport = draft.find((m) => m.id === id);
+            if (modelImport) {
+              modelImport.status = 'error';
+              modelImport.error_reason = error_type;
+              modelImport.error = error;
+            }
+            return draft;
+          })
+        );
+      }
     },
   });

   startAppListening({
     actionCreator: socketModelInstallCancelled,
-    effect: (action, { dispatch }) => {
+    effect: (action, { dispatch, getState }) => {
       const { id } = action.payload.data;
+      const { data } = selectModelInstalls(getState());

-      dispatch(
-        modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
-          const modelImport = draft.find((m) => m.id === id);
-          if (modelImport) {
-            modelImport.status = 'cancelled';
-          }
-          return draft;
-        })
-      );
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
+        dispatch(
+          modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+            const modelImport = draft.find((m) => m.id === id);
+            if (modelImport) {
+              modelImport.status = 'cancelled';
+            }
+            return draft;
+          })
+        );
+      }
+    },
+  });
+
+  startAppListening({
+    actionCreator: socketModelInstallDownloadsComplete,
+    effect: (action, { dispatch, getState }) => {
+      const { id } = action.payload.data;
+      const { data } = selectModelInstalls(getState());
+
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
+        dispatch(
+          modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+            const modelImport = draft.find((m) => m.id === id);
+            if (modelImport) {
+              modelImport.status = 'downloads_done';
+            }
+            return draft;
+          })
+        );
+      }
     },
   });
 };
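The comment block above describes the install event flow; each listener in the new version then repeats the same cache handling: if the event's install `id` is not in the cached `listModelInstalls` result, invalidate and refetch the list, otherwise patch the cached entry in place. A sketch of that shared pattern as a standalone helper, for orientation only (this helper is hypothetical and not part of the diff; the `AppDispatch`, `RootState`, and `ModelInstallJob` type names are assumptions):

```ts
// Hypothetical helper illustrating the update-or-invalidate pattern the listeners repeat.
const updateInstallOrInvalidate = (
  dispatch: AppDispatch,
  getState: () => RootState,
  id: number,
  update: (install: ModelInstallJob) => void
) => {
  const { data } = selectModelInstalls(getState());
  if (!data || !data.find((m) => m.id === id)) {
    // The install is not in the cached list yet: refetch the whole list.
    dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
  } else {
    // The install is cached: patch the matching entry in place.
    dispatch(
      modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
        const install = draft.find((m) => m.id === id);
        if (install) {
          update(install);
        }
        return draft;
      })
    );
  }
};
```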
@@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
 import { parseify } from 'common/util/serialize';
 import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
 import { $templates } from 'features/nodes/store/nodesSlice';
-import { $flow } from 'features/nodes/store/reactFlowInstance';
+import { $needsFit } from 'features/nodes/store/reactFlowInstance';
 import type { Templates } from 'features/nodes/store/types';
 import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
 import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
@@ -65,9 +65,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
         });
       }

-      requestAnimationFrame(() => {
-        $flow.get()?.fitView();
-      });
+      $needsFit.set(true);
     } catch (e) {
       if (e instanceof WorkflowVersionError) {
         // The workflow version was not recognized in the valid list of versions
@@ -35,6 +35,7 @@ type IAIDndImageProps = FlexProps & {
   draggableData?: TypesafeDraggableData;
   dropLabel?: ReactNode;
   isSelected?: boolean;
+  isSelectedForCompare?: boolean;
   thumbnail?: boolean;
   noContentFallback?: ReactElement;
   useThumbailFallback?: boolean;
@@ -61,6 +62,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
     draggableData,
     dropLabel,
     isSelected = false,
+    isSelectedForCompare = false,
     thumbnail = false,
     noContentFallback = defaultNoContentFallback,
     uploadElement = defaultUploadElement,
@@ -165,7 +167,11 @@ const IAIDndImage = (props: IAIDndImageProps) => {
             data-testid={dataTestId}
           />
           {withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
-          <SelectionOverlay isSelected={isSelected} isHovered={withHoverOverlay ? isHovered : false} />
+          <SelectionOverlay
+            isSelected={isSelected}
+            isSelectedForCompare={isSelectedForCompare}
+            isHovered={withHoverOverlay ? isHovered : false}
+          />
         </Flex>
       )}
       {!imageDTO && !isUploadDisabled && (
@@ -36,7 +36,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
       pointerEvents={active ? 'auto' : 'none'}
     >
       <AnimatePresence>
-        {isValidDrop(data, active) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
+        {isValidDrop(data, active?.data.current) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
       </AnimatePresence>
     </Box>
   );
@@ -3,10 +3,17 @@ import { memo, useMemo } from 'react';

 type Props = {
   isSelected: boolean;
+  isSelectedForCompare: boolean;
   isHovered: boolean;
 };
-const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
+const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props) => {
   const shadow = useMemo(() => {
+    if (isSelectedForCompare && isHovered) {
+      return 'hoverSelectedForCompare';
+    }
+    if (isSelectedForCompare && !isHovered) {
+      return 'selectedForCompare';
+    }
     if (isSelected && isHovered) {
       return 'hoverSelected';
     }
@@ -17,7 +24,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
       return 'hoverUnselected';
     }
     return undefined;
-  }, [isHovered, isSelected]);
+  }, [isHovered, isSelected, isSelectedForCompare]);
   return (
     <Box
       className="selection-box"
@@ -27,7 +34,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
       bottom={0}
       insetInlineStart={0}
       borderRadius="base"
-      opacity={isSelected ? 1 : 0.7}
+      opacity={isSelected || isSelectedForCompare ? 1 : 0.7}
       transitionProperty="common"
       transitionDuration="0.1s"
       pointerEvents="none"
21  invokeai/frontend/web/src/common/hooks/useBoolean.ts  Normal file
@@ -0,0 +1,21 @@
+import { useCallback, useMemo, useState } from 'react';
+
+export const useBoolean = (initialValue: boolean) => {
+  const [isTrue, set] = useState(initialValue);
+  const setTrue = useCallback(() => set(true), []);
+  const setFalse = useCallback(() => set(false), []);
+  const toggle = useCallback(() => set((v) => !v), []);
+
+  const api = useMemo(
+    () => ({
+      isTrue,
+      set,
+      setTrue,
+      setFalse,
+      toggle,
+    }),
+    [isTrue, set, setTrue, setFalse, toggle]
+  );
+
+  return api;
+};
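As a quick orientation for the new hook, a minimal usage sketch (the component below is hypothetical and not part of this diff):

```tsx
import { useBoolean } from 'common/hooks/useBoolean';

// Hypothetical consumer: shows or hides a panel via the hook's stable callbacks.
const PanelToggleExample = () => {
  const panel = useBoolean(false);
  return (
    <div>
      <button onClick={panel.toggle}>{panel.isTrue ? 'Hide' : 'Show'}</button>
      {panel.isTrue && <div>Panel content</div>}
    </div>
  );
};
```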
@@ -1,3 +1,7 @@
 export const stopPropagation = (e: React.MouseEvent) => {
   e.stopPropagation();
 };
+
+export const preventDefault = (e: React.MouseEvent) => {
+  e.preventDefault();
+};
@@ -4,6 +4,7 @@ import {
   caLayerControlModeChanged,
   caLayerImageChanged,
   caLayerModelChanged,
+  caLayerProcessedImageChanged,
   caLayerProcessorConfigChanged,
   caOrIPALayerBeginEndStepPctChanged,
   caOrIPALayerWeightChanged,
@@ -84,6 +85,14 @@ export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => {
     [dispatch, layerId]
   );

+  const onErrorLoadingImage = useCallback(() => {
+    dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
+  }, [dispatch, layerId]);
+
+  const onErrorLoadingProcessedImage = useCallback(() => {
+    dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
+  }, [dispatch, layerId]);
+
   const droppableData = useMemo<CALayerImageDropData>(
     () => ({
       actionType: 'SET_CA_LAYER_IMAGE',
@@ -114,6 +123,8 @@ export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => {
       onChangeImage={onChangeImage}
       droppableData={droppableData}
       postUploadAction={postUploadAction}
+      onErrorLoadingImage={onErrorLoadingImage}
+      onErrorLoadingProcessedImage={onErrorLoadingProcessedImage}
     />
   );
 });
@@ -28,6 +28,8 @@ type Props = {
   onChangeProcessorConfig: (processorConfig: ProcessorConfig | null) => void;
   onChangeModel: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => void;
   onChangeImage: (imageDTO: ImageDTO | null) => void;
+  onErrorLoadingImage: () => void;
+  onErrorLoadingProcessedImage: () => void;
   droppableData: TypesafeDroppableData;
   postUploadAction: PostUploadAction;
 };
@@ -41,6 +43,8 @@ export const ControlAdapter = memo(
     onChangeProcessorConfig,
     onChangeModel,
     onChangeImage,
+    onErrorLoadingImage,
+    onErrorLoadingProcessedImage,
     droppableData,
     postUploadAction,
   }: Props) => {
@@ -91,6 +95,8 @@ export const ControlAdapter = memo(
           onChangeImage={onChangeImage}
           droppableData={droppableData}
           postUploadAction={postUploadAction}
+          onErrorLoadingImage={onErrorLoadingImage}
+          onErrorLoadingProcessedImage={onErrorLoadingProcessedImage}
         />
       </Flex>
     </Flex>
@@ -27,10 +27,19 @@ type Props = {
   onChangeImage: (imageDTO: ImageDTO | null) => void;
   droppableData: TypesafeDroppableData;
   postUploadAction: PostUploadAction;
+  onErrorLoadingImage: () => void;
+  onErrorLoadingProcessedImage: () => void;
 };

 export const ControlAdapterImagePreview = memo(
-  ({ controlAdapter, onChangeImage, droppableData, postUploadAction }: Props) => {
+  ({
+    controlAdapter,
+    onChangeImage,
+    droppableData,
+    postUploadAction,
+    onErrorLoadingImage,
+    onErrorLoadingProcessedImage,
+  }: Props) => {
     const { t } = useTranslation();
     const dispatch = useAppDispatch();
     const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
@@ -128,10 +137,23 @@ export const ControlAdapterImagePreview = memo(
       controlAdapter.processorConfig !== null;

     useEffect(() => {
-      if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) {
-        handleResetControlImage();
+      if (!isConnected) {
+        return;
       }
-    }, [handleResetControlImage, isConnected, isErrorControlImage, isErrorProcessedControlImage]);
+      if (isErrorControlImage) {
+        onErrorLoadingImage();
+      }
+      if (isErrorProcessedControlImage) {
+        onErrorLoadingProcessedImage();
+      }
+    }, [
+      handleResetControlImage,
+      isConnected,
+      isErrorControlImage,
+      isErrorProcessedControlImage,
+      onErrorLoadingImage,
+      onErrorLoadingProcessedImage,
+    ]);

     return (
       <Flex
@@ -167,6 +189,7 @@ export const ControlAdapterImagePreview = memo(
           droppableData={droppableData}
           imageDTO={processedControlImage}
           isUploadDisabled={true}
+          onError={handleResetControlImage}
         />
       </Box>

@@ -4,20 +4,35 @@ import { createSelector } from '@reduxjs/toolkit';
 import { logger } from 'app/logging/logger';
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import { useMouseEvents } from 'features/controlLayers/hooks/mouseEventHooks';
+import { BRUSH_SPACING_PCT, MAX_BRUSH_SPACING_PX, MIN_BRUSH_SPACING_PX } from 'features/controlLayers/konva/constants';
+import { setStageEventHandlers } from 'features/controlLayers/konva/events';
+import { debouncedRenderers, renderers as normalRenderers } from 'features/controlLayers/konva/renderers';
 import {
+  $brushSize,
+  $brushSpacingPx,
+  $isDrawing,
+  $lastAddedPoint,
   $lastCursorPos,
   $lastMouseDownPos,
+  $selectedLayerId,
+  $selectedLayerType,
+  $shouldInvertBrushSizeScrollDirection,
   $tool,
+  brushSizeChanged,
   isRegionalGuidanceLayer,
   layerBboxChanged,
   layerTranslated,
+  rgLayerLineAdded,
+  rgLayerPointsAdded,
+  rgLayerRectAdded,
   selectControlLayersSlice,
 } from 'features/controlLayers/store/controlLayersSlice';
-import { debouncedRenderers, renderers as normalRenderers } from 'features/controlLayers/util/renderers';
+import type { AddLineArg, AddPointToLineArg, AddRectArg } from 'features/controlLayers/store/types';
 import Konva from 'konva';
 import type { IRect } from 'konva/lib/types';
+import { clamp } from 'lodash-es';
 import { memo, useCallback, useLayoutEffect, useMemo, useState } from 'react';
+import { getImageDTO } from 'services/api/endpoints/images';
 import { useDevicePixelRatio } from 'use-device-pixel-ratio';
 import { v4 as uuidv4 } from 'uuid';

@@ -47,7 +62,6 @@ const useStageRenderer = (
   const dispatch = useAppDispatch();
   const state = useAppSelector((s) => s.controlLayers.present);
   const tool = useStore($tool);
-  const mouseEventHandlers = useMouseEvents();
   const lastCursorPos = useStore($lastCursorPos);
   const lastMouseDownPos = useStore($lastMouseDownPos);
   const selectedLayerIdColor = useAppSelector(selectSelectedLayerColor);
@@ -56,6 +70,26 @@
   const layerCount = useMemo(() => state.layers.length, [state.layers]);
   const renderers = useMemo(() => (asPreview ? debouncedRenderers : normalRenderers), [asPreview]);
   const dpr = useDevicePixelRatio({ round: false });
+  const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
+  const brushSpacingPx = useMemo(
+    () => clamp(state.brushSize / BRUSH_SPACING_PCT, MIN_BRUSH_SPACING_PX, MAX_BRUSH_SPACING_PX),
+    [state.brushSize]
+  );
+
+  useLayoutEffect(() => {
+    $brushSize.set(state.brushSize);
+    $brushSpacingPx.set(brushSpacingPx);
+    $selectedLayerId.set(state.selectedLayerId);
+    $selectedLayerType.set(selectedLayerType);
+    $shouldInvertBrushSizeScrollDirection.set(shouldInvertBrushSizeScrollDirection);
+  }, [
+    brushSpacingPx,
+    selectedLayerIdColor,
+    selectedLayerType,
+    shouldInvertBrushSizeScrollDirection,
+    state.brushSize,
+    state.selectedLayerId,
+  ]);

   const onLayerPosChanged = useCallback(
     (layerId: string, x: number, y: number) => {
@@ -71,6 +105,31 @@
     [dispatch]
   );
+
+  const onRGLayerLineAdded = useCallback(
+    (arg: AddLineArg) => {
+      dispatch(rgLayerLineAdded(arg));
+    },
+    [dispatch]
+  );
+  const onRGLayerPointAddedToLine = useCallback(
+    (arg: AddPointToLineArg) => {
+      dispatch(rgLayerPointsAdded(arg));
+    },
+    [dispatch]
+  );
+  const onRGLayerRectAdded = useCallback(
+    (arg: AddRectArg) => {
+      dispatch(rgLayerRectAdded(arg));
+    },
+    [dispatch]
+  );
+  const onBrushSizeChanged = useCallback(
+    (size: number) => {
+      dispatch(brushSizeChanged(size));
+    },
+    [dispatch]
+  );

   useLayoutEffect(() => {
     log.trace('Initializing stage');
     if (!container) {
@@ -88,21 +147,29 @@
     if (asPreview) {
       return;
     }
-    stage.on('mousedown', mouseEventHandlers.onMouseDown);
-    stage.on('mouseup', mouseEventHandlers.onMouseUp);
-    stage.on('mousemove', mouseEventHandlers.onMouseMove);
-    stage.on('mouseleave', mouseEventHandlers.onMouseLeave);
-    stage.on('wheel', mouseEventHandlers.onMouseWheel);
+    const cleanup = setStageEventHandlers({
+      stage,
+      $tool,
+      $isDrawing,
+      $lastMouseDownPos,
+      $lastCursorPos,
+      $lastAddedPoint,
+      $brushSize,
+      $brushSpacingPx,
+      $selectedLayerId,
+      $selectedLayerType,
+      $shouldInvertBrushSizeScrollDirection,
+      onRGLayerLineAdded,
+      onRGLayerPointAddedToLine,
+      onRGLayerRectAdded,
+      onBrushSizeChanged,
+    });

     return () => {
-      log.trace('Cleaning up stage listeners');
-      stage.off('mousedown', mouseEventHandlers.onMouseDown);
-      stage.off('mouseup', mouseEventHandlers.onMouseUp);
-      stage.off('mousemove', mouseEventHandlers.onMouseMove);
-      stage.off('mouseleave', mouseEventHandlers.onMouseLeave);
-      stage.off('wheel', mouseEventHandlers.onMouseWheel);
+      log.trace('Removing stage listeners');
+      cleanup();
     };
-  }, [stage, asPreview, mouseEventHandlers]);
+  }, [asPreview, onBrushSizeChanged, onRGLayerLineAdded, onRGLayerPointAddedToLine, onRGLayerRectAdded, stage]);

   useLayoutEffect(() => {
     log.trace('Updating stage dimensions');
@@ -160,7 +227,7 @@

   useLayoutEffect(() => {
     log.trace('Rendering layers');
-    renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged);
+    renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, getImageDTO, onLayerPosChanged);
   }, [
     stage,
     state.layers,
Some files were not shown because too many files have changed in this diff.