mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 08:28:14 -05:00
Compare commits
218 Commits
feat/upsca
...
release/v3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c439e9681e | ||
|
|
369963791e | ||
|
|
0d94bed9f8 | ||
|
|
cec8ad57a5 | ||
|
|
003c2c28c9 | ||
|
|
661b3056ed | ||
|
|
20f7e448c3 | ||
|
|
aa82f9360c | ||
|
|
5aefa49d7d | ||
|
|
b6e9cd4fe2 | ||
|
|
6d1057c560 | ||
|
|
b4790002c7 | ||
|
|
e02700a782 | ||
|
|
83ce8ef1ec | ||
|
|
19e487b5ee | ||
|
|
aa4b56baf2 | ||
|
|
d3a2be69f1 | ||
|
|
02c087ee37 | ||
|
|
cab8d9bb20 | ||
|
|
28e6a7139b | ||
|
|
1625854eaf | ||
|
|
f87b042162 | ||
|
|
183e2c3ee0 | ||
|
|
098d506b95 | ||
|
|
7aa33c352b | ||
|
|
bf62553150 | ||
|
|
2b08d9e53b | ||
|
|
8954953eca | ||
|
|
eb2fcbe28a | ||
|
|
e78b36a9f7 | ||
|
|
144ede031e | ||
|
|
8ca37bba33 | ||
|
|
a608340c89 | ||
|
|
7fecebf7db | ||
|
|
b915d74127 | ||
|
|
6ec347bd41 | ||
|
|
e54843acc9 | ||
|
|
0960518088 | ||
|
|
21de74fac4 | ||
|
|
8ce9b6c51e | ||
|
|
b64ade586d | ||
|
|
3c44a74ba5 | ||
|
|
24d0901d8e | ||
|
|
b1b5f70ea6 | ||
|
|
6392098961 | ||
|
|
2c39aec22d | ||
|
|
d066bc6d19 | ||
|
|
e487bcd0f7 | ||
|
|
e0f8274f49 | ||
|
|
69e3513e90 | ||
|
|
7e706f02cb | ||
|
|
41dad2013a | ||
|
|
3f554d6824 | ||
|
|
202c5a48c6 | ||
|
|
2d71f6f4b8 | ||
|
|
0420874f56 | ||
|
|
f222b871e9 | ||
|
|
8b8d589033 | ||
|
|
f4c895257a | ||
|
|
10af5a26f2 | ||
|
|
1088adeb0a | ||
|
|
ad49380cd1 | ||
|
|
b2fe24c401 | ||
|
|
b128db1d58 | ||
|
|
f7f0630d97 | ||
|
|
5075e9c899 | ||
|
|
3c1549cf5c | ||
|
|
9faa53ceb1 | ||
|
|
32672cfeda | ||
|
|
b5266f89ad | ||
|
|
7a3b467ce0 | ||
|
|
bdfdf854fc | ||
|
|
1c38cce16d | ||
|
|
4cdca45228 | ||
|
|
bfed08673a | ||
|
|
c1aa2b82eb | ||
|
|
0a09f84b07 | ||
|
|
b7938d9ca9 | ||
|
|
977e348a35 | ||
|
|
864f2270c3 | ||
|
|
8b44d83859 | ||
|
|
0b6315de71 | ||
|
|
578e682562 | ||
|
|
92b49e45bb | ||
|
|
b05b8ef677 | ||
|
|
382e2139bd | ||
|
|
d7ebe3f048 | ||
|
|
5c2bdf626b | ||
|
|
390a1c9fbb | ||
|
|
c46d9b8768 | ||
|
|
ef8d9843dd | ||
|
|
dc2e1a42bc | ||
|
|
1869874433 | ||
|
|
94f16b1c69 | ||
|
|
cc0482ae8b | ||
|
|
fdf9833c39 | ||
|
|
5a961bb58e | ||
|
|
627750eded | ||
|
|
2a3909da94 | ||
|
|
e0dddbd38e | ||
|
|
231b7a5000 | ||
|
|
b7773c9962 | ||
|
|
11c501fc80 | ||
|
|
7be5743011 | ||
|
|
c48e648cbb | ||
|
|
29b4ddcc7f | ||
|
|
7ee13879e3 | ||
|
|
ced297ed21 | ||
|
|
3e813ead1f | ||
|
|
820ec08e9a | ||
|
|
4dd289b337 | ||
|
|
b60b1e359e | ||
|
|
208286e97a | ||
|
|
f7b64304ae | ||
|
|
834751e877 | ||
|
|
e7a10d310f | ||
|
|
2ce07a4730 | ||
|
|
45d5ab20ec | ||
|
|
343df03a92 | ||
|
|
b57acb7353 | ||
|
|
7bf7c16a5d | ||
|
|
56340c24c8 | ||
|
|
afe9756667 | ||
|
|
fcea65770f | ||
|
|
16664da5b6 | ||
|
|
c104807201 | ||
|
|
990ce9a1da | ||
|
|
18095ecc44 | ||
|
|
fe19f11abf | ||
|
|
c2f074dc2f | ||
|
|
e02a557454 | ||
|
|
fca60862e2 | ||
|
|
94c186bb4c | ||
|
|
a22c8cb3a1 | ||
|
|
781e8521d5 | ||
|
|
d114d0ba95 | ||
|
|
cc8b7a74da | ||
|
|
388554448a | ||
|
|
cadc0839a6 | ||
|
|
d5160648d0 | ||
|
|
6d0ea42a94 | ||
|
|
2c1100509f | ||
|
|
c34b359c36 | ||
|
|
77d135967f | ||
|
|
ebf26687cb | ||
|
|
1c8991a3df | ||
|
|
3d52656176 | ||
|
|
a2777decd4 | ||
|
|
d219167849 | ||
|
|
090db1ab3a | ||
|
|
468253aa14 | ||
|
|
3ee9a21647 | ||
|
|
0d823901ef | ||
|
|
7ee55489bb | ||
|
|
163ece9aee | ||
|
|
7b2e6deaf1 | ||
|
|
63f94579c5 | ||
|
|
3dfff278aa | ||
|
|
aa7d945b23 | ||
|
|
88db094cf2 | ||
|
|
50a0691514 | ||
|
|
a255624984 | ||
|
|
2630fe3608 | ||
|
|
dee6f86d5e | ||
|
|
6ca6cf713c | ||
|
|
3f7d5b4e0f | ||
|
|
91596d9527 | ||
|
|
d669f0855d | ||
|
|
b2d5b53b5f | ||
|
|
ddc148b70b | ||
|
|
c2d43f007b | ||
|
|
7703bf2ca1 | ||
|
|
23fdf0156f | ||
|
|
cdbf40c9b2 | ||
|
|
46c9dcb113 | ||
|
|
6df79045fa | ||
|
|
d776e0a0a9 | ||
|
|
94ec3da7b5 | ||
|
|
f44496a579 | ||
|
|
99fe95ab03 | ||
|
|
95ecb1a0c1 | ||
|
|
bd15874cf6 | ||
|
|
30ab81b6bb | ||
|
|
78195491bc | ||
|
|
c63390f6e1 | ||
|
|
cbd451c610 | ||
|
|
b0f91f2e75 | ||
|
|
3ac68cde66 | ||
|
|
a69b1cd598 | ||
|
|
65a76a086b | ||
|
|
07381e5a26 | ||
|
|
6bb378a101 | ||
|
|
7df67d077a | ||
|
|
b761807219 | ||
|
|
fb1b03960e | ||
|
|
74bfb5e1f9 | ||
|
|
bc1bce18b0 | ||
|
|
942ecbbde4 | ||
|
|
79db0e9e93 | ||
|
|
0c17f8604f | ||
|
|
054edc4077 | ||
|
|
5a9993772d | ||
|
|
f2cd9e9ae2 | ||
|
|
9f86cfa471 | ||
|
|
8c1390166f | ||
|
|
1ad98ce999 | ||
|
|
5f4a62810e | ||
|
|
35b7ae90ae | ||
|
|
9ed4d487d2 | ||
|
|
69d37217b8 | ||
|
|
7afdefb0e5 | ||
|
|
dff466244d | ||
|
|
f5d95ffed5 | ||
|
|
6f9c1c6d4e | ||
|
|
811c82a677 | ||
|
|
4f0e43ec1b | ||
|
|
26a7b7b66d | ||
|
|
8611ffe32d |
@@ -159,7 +159,7 @@ groups in `invokeia.yaml`:
|
||||
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
||||
| `port` | `9090` | Network port number that the web server will listen on |
|
||||
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
||||
| `allow_credentials | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
||||
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
||||
|
||||
|
||||
336
docs/features/UTILITIES.md
Normal file
336
docs/features/UTILITIES.md
Normal file
@@ -0,0 +1,336 @@
|
||||
---
|
||||
title: Command-line Utilities
|
||||
---
|
||||
|
||||
# :material-file-document: Utilities
|
||||
|
||||
# Command-line Utilities
|
||||
|
||||
InvokeAI comes with several scripts that are accessible via the
|
||||
command line. To access these commands, start the "developer's
|
||||
console" from the launcher (`invoke.bat` menu item [8]). Users who are
|
||||
familiar with Python can alternatively activate InvokeAI's virtual
|
||||
environment (typically, but not necessarily `invokeai/.venv`).
|
||||
|
||||
In the developer's console, type the script's name to run it. To get a
|
||||
synopsis of what a utility does and the command-line arguments it
|
||||
accepts, pass it the `-h` argument, e.g.
|
||||
|
||||
```bash
|
||||
invokeai-merge -h
|
||||
```
|
||||
## **invokeai-web**
|
||||
|
||||
This script launches the web server and is effectively identical to
|
||||
selecting option [1] in the launcher. An advantage of launching the
|
||||
server from the command line is that you can override any setting
|
||||
configuration option in `invokeai.yaml` using like-named command-line
|
||||
arguments. For example, to temporarily change the size of the RAM
|
||||
cache to 7 GB, you can launch as follows:
|
||||
|
||||
```bash
|
||||
invokeai-web --ram 7
|
||||
```
|
||||
|
||||
## **invokeai-merge**
|
||||
|
||||
This is the model merge script, the same as launcher option [4]. Call
|
||||
it with the `--gui` command-line argument to start the interactive
|
||||
console-based GUI. Alternatively, you can run it non-interactively
|
||||
using command-line arguments as illustrated in the example below which
|
||||
merges models named `stable-diffusion-1.5` and `inkdiffusion` into a new model named
|
||||
`my_new_model`:
|
||||
|
||||
```bash
|
||||
invokeai-merge --force --base-model sd-1 --models stable-diffusion-1.5 inkdiffusion --merged_model_name my_new_model
|
||||
```
|
||||
|
||||
## **invokeai-ti**
|
||||
|
||||
This is the textual inversion training script that is run by launcher
|
||||
option [3]. Call it with `--gui` to run the interactive console-based
|
||||
front end. It can also be run non-interactively. It has about a
|
||||
zillion arguments, but a typical training session can be launched
|
||||
with:
|
||||
|
||||
```bash
|
||||
invokeai-ti --model stable-diffusion-1.5 \
|
||||
--placeholder_token 'jello' \
|
||||
--learnable_property object \
|
||||
--num_train_epochs 50 \
|
||||
--train_data_dir /path/to/training/images \
|
||||
--output_dir /path/to/trained/model
|
||||
```
|
||||
|
||||
(Note that \\ is the Linux/Mac long-line continuation character. Use ^
|
||||
in Windows).
|
||||
|
||||
## **invokeai-install**
|
||||
|
||||
This is the console-based model install script that is run by launcher
|
||||
option [5]. If called without arguments, it will launch the
|
||||
interactive console-based interface. It can also be used
|
||||
non-interactively to list, add and remove models as shown by these
|
||||
examples:
|
||||
|
||||
* This will download and install three models from CivitAI, HuggingFace,
|
||||
and local disk:
|
||||
|
||||
```bash
|
||||
invokeai-install --add https://civitai.com/api/download/models/161302 ^
|
||||
gsdf/Counterfeit-V3.0 ^
|
||||
D:\Models\merge_model_two.safetensors
|
||||
```
|
||||
(Note that ^ is the Windows long-line continuation character. Use \\ on
|
||||
Linux/Mac).
|
||||
|
||||
* This will list installed models of type `main`:
|
||||
|
||||
```bash
|
||||
invokeai-model-install --list-models main
|
||||
```
|
||||
|
||||
* This will delete the models named `voxel-ish` and `realisticVision`:
|
||||
|
||||
```bash
|
||||
invokeai-model-install --delete voxel-ish realisticVision
|
||||
```
|
||||
|
||||
## **invokeai-configure**
|
||||
|
||||
This is the console-based configure script that ran when InvokeAI was
|
||||
first installed. You can run it again at any time to change the
|
||||
configuration, repair a broken install.
|
||||
|
||||
Called without any arguments, `invokeai-configure` enters interactive
|
||||
mode with two screens. The first screen is a form that provides access
|
||||
to most of InvokeAI's configuration options. The second screen lets
|
||||
you download, add, and delete models interactively. When you exit the
|
||||
second screen, the script will add any missing "support models"
|
||||
needed for core functionality, and any selected "sd weights" which are
|
||||
the model checkpoint/diffusers files.
|
||||
|
||||
This behavior can be changed via a series of command-line
|
||||
arguments. Here are some of the useful ones:
|
||||
|
||||
* `invokeai-configure --skip-sd-weights --skip-support-models`
|
||||
This will run just the configuration part of the utility, skipping
|
||||
downloading of support models and stable diffusion weights.
|
||||
|
||||
* `invokeai-configure --yes`
|
||||
This will run the configure script non-interactively. It will set the
|
||||
configuration options to their default values, install/repair support
|
||||
models, and download the "recommended" set of SD models.
|
||||
|
||||
* `invokeai-configure --yes --default_only`
|
||||
This will run the configure script non-interactively. In contrast to
|
||||
the previous command, it will only download the default SD model,
|
||||
Stable Diffusion v1.5
|
||||
|
||||
* `invokeai-configure --yes --default_only --skip-sd-weights`
|
||||
This is similar to the previous command, but will not download any
|
||||
SD models at all. It is usually used to repair a broken install.
|
||||
|
||||
By default, `invokeai-configure` runs on the currently active InvokeAI
|
||||
root folder. To run it against a different root, pass it the `--root
|
||||
</path/to/root>` argument.
|
||||
|
||||
Lastly, you can use `invokeai-configure` to create a working root
|
||||
directory entirely from scratch. Assuming you wish to make a root directory
|
||||
named `InvokeAI-New`, run this command:
|
||||
|
||||
```bash
|
||||
invokeai-configure --root InvokeAI-New --yes --default_only
|
||||
```
|
||||
This will create a minimally functional root directory. You can now
|
||||
launch the web server against it with `invokeai-web --root InvokeAI-New`.
|
||||
|
||||
## **invokeai-update**
|
||||
|
||||
This is the interactive console-based script that is run by launcher
|
||||
menu item [9] to update to a new version of InvokeAI. It takes no
|
||||
command-line arguments.
|
||||
|
||||
## **invokeai-metadata**
|
||||
|
||||
This is a script which takes a list of InvokeAI-generated images and
|
||||
outputs their metadata in the same JSON format that you get from the
|
||||
`</>` button in the Web GUI. For example:
|
||||
|
||||
```bash
|
||||
$ invokeai-metadata ffe2a115-b492-493c-afff-7679aa034b50.png
|
||||
ffe2a115-b492-493c-afff-7679aa034b50.png:
|
||||
{
|
||||
"app_version": "3.1.0",
|
||||
"cfg_scale": 8.0,
|
||||
"clip_skip": 0,
|
||||
"controlnets": [],
|
||||
"generation_mode": "sdxl_txt2img",
|
||||
"height": 1024,
|
||||
"loras": [],
|
||||
"model": {
|
||||
"base_model": "sdxl",
|
||||
"model_name": "stable-diffusion-xl-base-1.0",
|
||||
"model_type": "main"
|
||||
},
|
||||
"negative_prompt": "",
|
||||
"negative_style_prompt": "",
|
||||
"positive_prompt": "military grade sushi dinner for shock troopers",
|
||||
"positive_style_prompt": "",
|
||||
"rand_device": "cpu",
|
||||
"refiner_cfg_scale": 7.5,
|
||||
"refiner_model": {
|
||||
"base_model": "sdxl-refiner",
|
||||
"model_name": "sd_xl_refiner_1.0",
|
||||
"model_type": "main"
|
||||
},
|
||||
"refiner_negative_aesthetic_score": 2.5,
|
||||
"refiner_positive_aesthetic_score": 6.0,
|
||||
"refiner_scheduler": "euler",
|
||||
"refiner_start": 0.8,
|
||||
"refiner_steps": 20,
|
||||
"scheduler": "euler",
|
||||
"seed": 387129902,
|
||||
"steps": 25,
|
||||
"width": 1024
|
||||
}
|
||||
```
|
||||
|
||||
You may list multiple files on the command line.
|
||||
|
||||
## **invokeai-import-images**
|
||||
|
||||
InvokeAI uses a database to store information about images it
|
||||
generated, and just copying the image files from one InvokeAI root
|
||||
directory to another does not automatically import those images into
|
||||
the destination's gallery. This script allows you to bulk import
|
||||
images generated by one instance of InvokeAI into a gallery maintained
|
||||
by another. It also works on images generated by older versions of
|
||||
InvokeAI, going way back to version 1.
|
||||
|
||||
This script has an interactive mode only. The following example shows
|
||||
it in action:
|
||||
|
||||
```bash
|
||||
$ invokeai-import-images
|
||||
===============================================================================
|
||||
This script will import images generated by earlier versions of
|
||||
InvokeAI into the currently installed root directory:
|
||||
/home/XXXX/invokeai-main
|
||||
If this is not what you want to do, type ctrl-C now to cancel.
|
||||
===============================================================================
|
||||
= Configuration & Settings
|
||||
Found invokeai.yaml file at /home/XXXX/invokeai-main/invokeai.yaml:
|
||||
Database : /home/XXXX/invokeai-main/databases/invokeai.db
|
||||
Outputs : /home/XXXX/invokeai-main/outputs/images
|
||||
|
||||
Use these paths for import (yes) or choose different ones (no) [Yn]:
|
||||
Inputs: Specify absolute path containing InvokeAI .png images to import: /home/XXXX/invokeai-2.3/outputs/images/
|
||||
Include files from subfolders recursively [yN]?
|
||||
|
||||
Options for board selection for imported images:
|
||||
1) Select an existing board name. (found 4)
|
||||
2) Specify a board name to create/add to.
|
||||
3) Create/add to board named 'IMPORT'.
|
||||
4) Create/add to board named 'IMPORT' with the current datetime string appended (.e.g IMPORT_20230919T203519Z).
|
||||
5) Create/add to board named 'IMPORT' with a the original file app_version appended (.e.g IMPORT_2.2.5).
|
||||
Specify desired board option: 3
|
||||
|
||||
===============================================================================
|
||||
= Import Settings Confirmation
|
||||
|
||||
Database File Path : /home/XXXX/invokeai-main/databases/invokeai.db
|
||||
Outputs/Images Directory : /home/XXXX/invokeai-main/outputs/images
|
||||
Import Image Source Directory : /home/XXXX/invokeai-2.3/outputs/images/
|
||||
Recurse Source SubDirectories : No
|
||||
Count of .png file(s) found : 5785
|
||||
Board name option specified : IMPORT
|
||||
Database backup will be taken at : /home/XXXX/invokeai-main/databases/backup
|
||||
|
||||
Notes about the import process:
|
||||
- Source image files will not be modified, only copied to the outputs directory.
|
||||
- If the same file name already exists in the destination, the file will be skipped.
|
||||
- If the same file name already has a record in the database, the file will be skipped.
|
||||
- Invoke AI metadata tags will be updated/written into the imported copy only.
|
||||
- On the imported copy, only Invoke AI known tags (latest and legacy) will be retained (dream, sd-metadata, invokeai, invokeai_metadata)
|
||||
- A property 'imported_app_version' will be added to metadata that can be viewed in the UI's metadata viewer.
|
||||
- The new 3.x InvokeAI outputs folder structure is flat so recursively found source imges will all be placed into the single outputs/images folder.
|
||||
|
||||
Do you wish to continue with the import [Yn] ?
|
||||
|
||||
Making DB Backup at /home/lstein/invokeai-main/databases/backup/backup-20230919T203519Z-invokeai.db...Done!
|
||||
|
||||
===============================================================================
|
||||
Importing /home/XXXX/invokeai-2.3/outputs/images/17d09907-297d-4db3-a18a-60b337feac66.png
|
||||
... (5785 more lines) ...
|
||||
===============================================================================
|
||||
= Import Complete - Elpased Time: 0.28 second(s)
|
||||
|
||||
Source File(s) : 5785
|
||||
Total Imported : 5783
|
||||
Skipped b/c file already exists on disk : 1
|
||||
Skipped b/c file already exists in db : 0
|
||||
Errors during import : 1
|
||||
```
|
||||
## **invokeai-db-maintenance**
|
||||
|
||||
This script helps maintain the integrity of your InvokeAI database by
|
||||
finding and fixing three problems that can arise over time:
|
||||
|
||||
1. An image was manually deleted from the outputs directory, leaving a
|
||||
dangling image record in the InvokeAI database. This will cause a
|
||||
black image to appear in the gallery. This is an "orphaned database
|
||||
image record." The script can fix this by running a "clean"
|
||||
operation on the database, removing the orphaned entries.
|
||||
|
||||
2. An image is present in the outputs directory but there is no
|
||||
corresponding entry in the database. This can happen when the image
|
||||
is added manually to the outputs directory, or if a crash occurred
|
||||
after the image was generated but before the database was
|
||||
completely updated. The symptom is that the image is present in the
|
||||
outputs folder but doesn't appear in the InvokeAI gallery. This is
|
||||
called an "orphaned image file." The script can fix this problem by
|
||||
running an "archive" operation in which orphaned files are moved
|
||||
into a directory named `outputs/images-archive`. If you wish, you
|
||||
can then run `invokeai-image-import` to reimport these images back
|
||||
into the database.
|
||||
|
||||
3. The thumbnail for an image is missing, again causing a black
|
||||
gallery thumbnail. This is fixed by running the "thumbnaiils"
|
||||
operation, which simply regenerates and re-registers the missing
|
||||
thumbnail.
|
||||
|
||||
You can find and fix all three of these problems in a single go by
|
||||
executing this command:
|
||||
|
||||
```bash
|
||||
invokeai-db-maintenance --operation all
|
||||
```
|
||||
|
||||
Or you can run just the clean and thumbnail operations like this:
|
||||
|
||||
```bash
|
||||
invokeai-db-maintenance -operation clean, thumbnail
|
||||
```
|
||||
|
||||
If called without any arguments, the script will ask you which
|
||||
operations you wish to perform.
|
||||
|
||||
## **invokeai-migrate3**
|
||||
|
||||
This script will migrate settings and models (but not images!) from an
|
||||
InvokeAI v2.3 root folder to an InvokeAI 3.X folder. Call it with the
|
||||
source and destination root folders like this:
|
||||
|
||||
```bash
|
||||
invokeai-migrate3 --from ~/invokeai-2.3 --to invokeai-3.1.1
|
||||
```
|
||||
|
||||
Both directories must previously have been properly created and
|
||||
initialized by `invokeai-configure`. If you wish to migrate the images
|
||||
contained in the older root as well, you can use the
|
||||
`invokeai-image-migrate` script described earlier.
|
||||
|
||||
---
|
||||
|
||||
Copyright (c) 2023, Lincoln Stein and the InvokeAI Development Team
|
||||
@@ -51,6 +51,9 @@ Prevent InvokeAI from displaying unwanted racy images.
|
||||
### * [Controlling Logging](LOGGING.md)
|
||||
Control how InvokeAI logs status messages.
|
||||
|
||||
### * [Command-line Utilities](UTILITIES.md)
|
||||
A list of the command-line utilities available with InvokeAI.
|
||||
|
||||
<!-- OUT OF DATE
|
||||
### * [Miscellaneous](OTHER.md)
|
||||
Run InvokeAI on Google Colab, generate images with repeating patterns,
|
||||
|
||||
@@ -147,6 +147,7 @@ Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
|
||||
|
||||
### InvokeAI Configuration
|
||||
- [Guide to InvokeAI Runtime Settings](features/CONFIGURATION.md)
|
||||
- [Database Maintenance and other Command Line Utilities](features/UTILITIES.md)
|
||||
|
||||
## :octicons-log-16: Important Changes Since Version 2.3
|
||||
|
||||
|
||||
@@ -196,6 +196,40 @@ Results after using the depth controlnet
|
||||
|
||||
--------------------------------
|
||||
|
||||
### Prompt Tools
|
||||
|
||||
**Description:** A set of InvokeAI nodes that add general prompt manipulation tools. These where written to accompany the PromptsFromFile node and other prompt generation nodes.
|
||||
|
||||
1. PromptJoin - Joins to prompts into one.
|
||||
2. PromptReplace - performs a search and replace on a prompt. With the option of using regex.
|
||||
3. PromptSplitNeg - splits a prompt into positive and negative using the old V2 method of [] for negative.
|
||||
4. PromptToFile - saves a prompt or collection of prompts to a file. one per line. There is an append/overwrite option.
|
||||
5. PTFieldsCollect - Converts image generation fields into a Json format string that can be passed to Prompt to file.
|
||||
6. PTFieldsExpand - Takes Json string and converts it to individual generation parameters This can be fed from the Prompts to file node.
|
||||
7. PromptJoinThree - Joins 3 prompt together.
|
||||
8. PromptStrength - This take a string and float and outputs another string in the format of (string)strength like the weighted format of compel.
|
||||
9. PromptStrengthCombine - This takes a collection of prompt strength strings and outputs a string in the .and() or .blend() format that can be fed into a proper prompt node.
|
||||
|
||||
See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/Prompt-tools-nodes
|
||||
|
||||
--------------------------------
|
||||
|
||||
### XY Image to Grid and Images to Grids nodes
|
||||
|
||||
**Description:** Image to grid nodes and supporting tools.
|
||||
|
||||
1. "Images To Grids" node - Takes a collection of images and creates a grid(s) of images. If there are more images than the size of a single grid then mutilple grids will be created until it runs out of images.
|
||||
2. "XYImage To Grid" node - Converts a collection of XYImages into a labeled Grid of images. The XYImages collection has to be built using the supporoting nodes. See example node setups for more details.
|
||||
|
||||
|
||||
See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/README.md
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/XYGrid_nodes
|
||||
|
||||
--------------------------------
|
||||
|
||||
### Example Node Template
|
||||
|
||||
**Description:** This node allows you to do super cool things with InvokeAI.
|
||||
|
||||
@@ -17,9 +17,10 @@ echo 6. Change InvokeAI startup options
|
||||
echo 7. Re-run the configure script to fix a broken install or to complete a major upgrade
|
||||
echo 8. Open the developer console
|
||||
echo 9. Update InvokeAI
|
||||
echo 10. Command-line help
|
||||
echo 10. Run the InvokeAI image database maintenance script
|
||||
echo 11. Command-line help
|
||||
echo Q - Quit
|
||||
set /P choice="Please enter 1-10, Q: [1] "
|
||||
set /P choice="Please enter 1-11, Q: [1] "
|
||||
if not defined choice set choice=1
|
||||
IF /I "%choice%" == "1" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
@@ -58,8 +59,11 @@ IF /I "%choice%" == "1" (
|
||||
echo Running invokeai-update...
|
||||
python -m invokeai.frontend.install.invokeai_update
|
||||
) ELSE IF /I "%choice%" == "10" (
|
||||
echo Running the db maintenance script...
|
||||
python .venv\Scripts\invokeai-db-maintenance.exe
|
||||
) ELSE IF /I "%choice%" == "11" (
|
||||
echo Displaying command line help...
|
||||
python .venv\Scripts\invokeai.exe --help %*
|
||||
python .venv\Scripts\invokeai-web.exe --help %*
|
||||
pause
|
||||
exit /b
|
||||
) ELSE IF /I "%choice%" == "q" (
|
||||
|
||||
@@ -97,13 +97,13 @@ do_choice() {
|
||||
;;
|
||||
10)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
printf "Running the db maintenance script\n"
|
||||
invokeai-db-maintenance --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
"HELP 1")
|
||||
11)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
invokeai-web --help
|
||||
;;
|
||||
*)
|
||||
clear
|
||||
@@ -125,7 +125,10 @@ do_dialog() {
|
||||
6 "Change InvokeAI startup options"
|
||||
7 "Re-run the configure script to fix a broken install or to complete a major upgrade"
|
||||
8 "Open the developer console"
|
||||
9 "Update InvokeAI")
|
||||
9 "Update InvokeAI"
|
||||
10 "Run the InvokeAI image database maintenance script"
|
||||
11 "Command-line help"
|
||||
)
|
||||
|
||||
choice=$(dialog --clear \
|
||||
--backtitle "\Zb\Zu\Z3InvokeAI" \
|
||||
@@ -157,9 +160,10 @@ do_line_input() {
|
||||
printf "7: Re-run the configure script to fix a broken install\n"
|
||||
printf "8: Open the developer console\n"
|
||||
printf "9: Update InvokeAI\n"
|
||||
printf "10: Command-line help\n"
|
||||
printf "10: Run the InvokeAI image database maintenance script\n"
|
||||
printf "11: Command-line help\n"
|
||||
printf "Q: Quit\n\n"
|
||||
read -p "Please enter 1-10, Q: [1] " yn
|
||||
read -p "Please enter 1-11, Q: [1] " yn
|
||||
choice=${yn:='1'}
|
||||
do_choice $choice
|
||||
clear
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
||||
@@ -9,7 +10,10 @@ from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
@@ -25,6 +29,7 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.thread import lock
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@@ -63,22 +68,32 @@ class ApiDependencies:
|
||||
output_folder = config.output_path
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_path = config.db_path
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
db_location = str(db_path)
|
||||
if config.use_memory_db:
|
||||
db_location = ":memory:"
|
||||
else:
|
||||
db_path = config.db_path
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
db_location = str(db_path)
|
||||
|
||||
logger.info(f"Using database at {db_location}")
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
|
||||
if config.log_sql:
|
||||
db_conn.set_trace_callback(print)
|
||||
db_conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
conn=db_conn, table_name="graph_executions", lock=lock
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
@@ -120,18 +135,29 @@ class ApiDependencies:
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
configuration=config,
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
|
||||
session_processor=DefaultSessionProcessor(),
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
try:
|
||||
lock.acquire()
|
||||
db_conn.execute("VACUUM;")
|
||||
db_conn.commit()
|
||||
logger.info("Cleaned database")
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
@staticmethod
|
||||
def shutdown():
|
||||
if ApiDependencies.invoker:
|
||||
|
||||
@@ -7,6 +7,7 @@ from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
@@ -103,3 +104,43 @@ async def set_log_level(
|
||||
"""Sets the log verbosity level"""
|
||||
ApiDependencies.invoker.services.logger.setLevel(level)
|
||||
return LogLevel(ApiDependencies.invoker.services.logger.level)
|
||||
|
||||
|
||||
@app_router.delete(
|
||||
"/invocation_cache",
|
||||
operation_id="clear_invocation_cache",
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
)
|
||||
async def clear_invocation_cache() -> None:
|
||||
"""Clears the invocation cache"""
|
||||
ApiDependencies.invoker.services.invocation_cache.clear()
|
||||
|
||||
|
||||
@app_router.put(
|
||||
"/invocation_cache/enable",
|
||||
operation_id="enable_invocation_cache",
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
)
|
||||
async def enable_invocation_cache() -> None:
|
||||
"""Clears the invocation cache"""
|
||||
ApiDependencies.invoker.services.invocation_cache.enable()
|
||||
|
||||
|
||||
@app_router.put(
|
||||
"/invocation_cache/disable",
|
||||
operation_id="disable_invocation_cache",
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
)
|
||||
async def disable_invocation_cache() -> None:
|
||||
"""Clears the invocation cache"""
|
||||
ApiDependencies.invoker.services.invocation_cache.disable()
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/invocation_cache/status",
|
||||
operation_id="get_invocation_cache_status",
|
||||
responses={200: {"model": InvocationCacheStatus}},
|
||||
)
|
||||
async def get_invocation_cache_status() -> InvocationCacheStatus:
|
||||
"""Clears the invocation cache"""
|
||||
return ApiDependencies.invoker.services.invocation_cache.get_status()
|
||||
|
||||
247
invokeai/app/api/routers/session_queue.py
Normal file
247
invokeai/app/api/routers/session_queue.py
Normal file
@@ -0,0 +1,247 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
EnqueueGraphResult,
|
||||
PruneResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||
|
||||
from ...services.graph import Graph
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
||||
|
||||
|
||||
class SessionQueueAndProcessorStatus(BaseModel):
|
||||
"""The overall status of session queue and processor"""
|
||||
|
||||
queue: SessionQueueStatus
|
||||
processor: SessionProcessorStatus
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/enqueue_graph",
|
||||
operation_id="enqueue_graph",
|
||||
responses={
|
||||
201: {"model": EnqueueGraphResult},
|
||||
},
|
||||
)
|
||||
async def enqueue_graph(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
graph: Graph = Body(description="The graph to enqueue"),
|
||||
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||
) -> EnqueueGraphResult:
|
||||
"""Enqueues a graph for single execution."""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.enqueue_graph(queue_id=queue_id, graph=graph, prepend=prepend)
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/enqueue_batch",
|
||||
operation_id="enqueue_batch",
|
||||
responses={
|
||||
201: {"model": EnqueueBatchResult},
|
||||
},
|
||||
)
|
||||
async def enqueue_batch(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch: Batch = Body(description="Batch to process"),
|
||||
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=prepend)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/list",
|
||||
operation_id="list_queue_items",
|
||||
responses={
|
||||
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
|
||||
},
|
||||
)
|
||||
async def list_queue_items(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
limit: int = Query(default=50, description="The number of items to fetch"),
|
||||
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
|
||||
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
|
||||
priority: int = Query(default=0, description="The pagination cursor priority"),
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
"""Gets all queue items (without graphs)"""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.list_queue_items(
|
||||
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
|
||||
)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/processor/resume",
|
||||
operation_id="resume",
|
||||
responses={200: {"model": SessionProcessorStatus}},
|
||||
)
|
||||
async def resume(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Resumes session processor"""
|
||||
return ApiDependencies.invoker.services.session_processor.resume()
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/processor/pause",
|
||||
operation_id="pause",
|
||||
responses={200: {"model": SessionProcessorStatus}},
|
||||
)
|
||||
async def Pause(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Pauses session processor"""
|
||||
return ApiDependencies.invoker.services.session_processor.pause()
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/cancel_by_batch_ids",
|
||||
operation_id="cancel_by_batch_ids",
|
||||
responses={200: {"model": CancelByBatchIDsResult}},
|
||||
)
|
||||
async def cancel_by_batch_ids(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
|
||||
) -> CancelByBatchIDsResult:
|
||||
"""Immediately cancels all queue items from the given batch ids"""
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/clear",
|
||||
operation_id="clear",
|
||||
responses={
|
||||
200: {"model": ClearResult},
|
||||
},
|
||||
)
|
||||
async def clear(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> ClearResult:
|
||||
"""Clears the queue entirely, immediately canceling the currently-executing session"""
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if queue_item is not None:
|
||||
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||
return clear_result
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/prune",
|
||||
operation_id="prune",
|
||||
responses={
|
||||
200: {"model": PruneResult},
|
||||
},
|
||||
)
|
||||
async def prune(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> PruneResult:
|
||||
"""Prunes all completed or errored queue items"""
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/current",
|
||||
operation_id="get_current_queue_item",
|
||||
responses={
|
||||
200: {"model": Optional[SessionQueueItem]},
|
||||
},
|
||||
)
|
||||
async def get_current_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently execution queue item"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/next",
|
||||
operation_id="get_next_queue_item",
|
||||
responses={
|
||||
200: {"model": Optional[SessionQueueItem]},
|
||||
},
|
||||
)
|
||||
async def get_next_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the next queue item, without executing it"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/status",
|
||||
operation_id="get_queue_status",
|
||||
responses={
|
||||
200: {"model": SessionQueueAndProcessorStatus},
|
||||
},
|
||||
)
|
||||
async def get_queue_status(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionQueueAndProcessorStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/b/{batch_id}/status",
|
||||
operation_id="get_batch_status",
|
||||
responses={
|
||||
200: {"model": BatchStatus},
|
||||
},
|
||||
)
|
||||
async def get_batch_status(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch_id: str = Path(description="The batch to get the status of"),
|
||||
) -> BatchStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/i/{item_id}",
|
||||
operation_id="get_queue_item",
|
||||
responses={
|
||||
200: {"model": SessionQueueItem},
|
||||
},
|
||||
)
|
||||
async def get_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to get"),
|
||||
) -> SessionQueueItem:
|
||||
"""Gets a queue item"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/i/{item_id}/cancel",
|
||||
operation_id="cancel_queue_item",
|
||||
responses={
|
||||
200: {"model": SessionQueueItem},
|
||||
},
|
||||
)
|
||||
async def cancel_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to cancel"),
|
||||
) -> SessionQueueItem:
|
||||
"""Deletes a queue item"""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
||||
@@ -23,12 +23,14 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid json"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def create_session(
|
||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
|
||||
queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
|
||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||
) -> GraphExecutionState:
|
||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||
session = ApiDependencies.invoker.create_execution_state(graph)
|
||||
session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
|
||||
return session
|
||||
|
||||
|
||||
@@ -36,6 +38,7 @@ async def create_session(
|
||||
"/",
|
||||
operation_id="list_sessions",
|
||||
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||
deprecated=True,
|
||||
)
|
||||
async def list_sessions(
|
||||
page: int = Query(default=0, description="The page of results to get"),
|
||||
@@ -57,6 +60,7 @@ async def list_sessions(
|
||||
200: {"model": GraphExecutionState},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str = Path(description="The id of the session to get"),
|
||||
@@ -77,6 +81,7 @@ async def get_session(
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def add_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
@@ -109,6 +114,7 @@ async def add_node(
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def update_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
@@ -142,6 +148,7 @@ async def update_node(
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def delete_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
@@ -172,6 +179,7 @@ async def delete_node(
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def add_edge(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
@@ -203,6 +211,7 @@ async def add_edge(
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def delete_edge(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
@@ -241,8 +250,10 @@ async def delete_edge(
|
||||
400: {"description": "The session has no invocations ready to invoke"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def invoke_session(
|
||||
queue_id: str = Query(description="The id of the queue to associate the session with"),
|
||||
session_id: str = Path(description="The id of the session to invoke"),
|
||||
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
||||
) -> Response:
|
||||
@@ -254,7 +265,7 @@ async def invoke_session(
|
||||
if session.is_complete():
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
ApiDependencies.invoker.invoke(session, invoke_all=all)
|
||||
ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@@ -262,6 +273,7 @@ async def invoke_session(
|
||||
"/{session_id}/invoke",
|
||||
operation_id="cancel_session_invoke",
|
||||
responses={202: {"description": "The invocation is canceled"}},
|
||||
deprecated=True,
|
||||
)
|
||||
async def cancel_session_invoke(
|
||||
session_id: str = Path(description="The id of the session to cancel"),
|
||||
|
||||
41
invokeai/app/api/routers/utilities.py
Normal file
41
invokeai/app/api/routers/utilities.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||
from fastapi import Body
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from pyparsing import ParseException
|
||||
|
||||
utilities_router = APIRouter(prefix="/v1/utilities", tags=["utilities"])
|
||||
|
||||
|
||||
class DynamicPromptsResponse(BaseModel):
|
||||
prompts: list[str]
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@utilities_router.post(
|
||||
"/dynamicprompts",
|
||||
operation_id="parse_dynamicprompts",
|
||||
responses={
|
||||
200: {"model": DynamicPromptsResponse},
|
||||
},
|
||||
)
|
||||
async def parse_dynamicprompts(
|
||||
prompt: str = Body(description="The prompt to parse with dynamicprompts"),
|
||||
max_prompts: int = Body(default=1000, description="The max number of prompts to generate"),
|
||||
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
|
||||
) -> DynamicPromptsResponse:
|
||||
"""Creates a batch process"""
|
||||
try:
|
||||
error: Optional[str] = None
|
||||
if combinatorial:
|
||||
generator = CombinatorialPromptGenerator()
|
||||
prompts = generator.generate(prompt, max_prompts=max_prompts)
|
||||
else:
|
||||
generator = RandomPromptGenerator()
|
||||
prompts = generator.generate(prompt, num_images=max_prompts)
|
||||
except ParseException as e:
|
||||
prompts = [prompt]
|
||||
error = str(e)
|
||||
return DynamicPromptsResponse(prompts=prompts if prompts else [""], error=error)
|
||||
@@ -3,34 +3,35 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from fastapi_socketio import SocketManager
|
||||
from socketio import ASGIApp, AsyncServer
|
||||
|
||||
from ..services.events import EventServiceBase
|
||||
|
||||
|
||||
class SocketIO:
|
||||
__sio: SocketManager
|
||||
__sio: AsyncServer
|
||||
__app: ASGIApp
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = SocketManager(app=app)
|
||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
||||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="socket.io")
|
||||
app.mount("/ws", self.__app)
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
||||
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
|
||||
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
async def _handle_queue_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["graph_execution_state_id"],
|
||||
room=event[1]["data"]["queue_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.enter_room(sid, data["session"])
|
||||
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
|
||||
if "queue_id" in data:
|
||||
self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
||||
# @app.sio.on('unsubscribe')
|
||||
|
||||
async def _handle_unsub(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.leave_room(sid, data["session"])
|
||||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
|
||||
if "queue_id" in data:
|
||||
self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||
@@ -33,7 +32,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import app_info, board_images, boards, images, models, sessions
|
||||
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||
|
||||
@@ -92,6 +91,8 @@ async def shutdown_event():
|
||||
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
@@ -102,6 +103,8 @@ app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
@@ -12,6 +14,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
import argparse
|
||||
import re
|
||||
import shlex
|
||||
import sqlite3
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional, Union, get_type_hints
|
||||
@@ -249,19 +252,18 @@ def invoke_cli():
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
@@ -303,12 +305,13 @@ def invoke_cli():
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
configuration=config,
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
|
||||
@@ -67,6 +67,7 @@ class FieldDescriptions:
|
||||
width = "Width of output (px)"
|
||||
height = "Height of output (px)"
|
||||
control = "ControlNet(s) to apply"
|
||||
ip_adapter = "IP-Adapter to apply"
|
||||
denoised_latents = "Denoised latents tensor"
|
||||
latents = "Latents tensor"
|
||||
strength = "Strength of denoising (proportional to steps)"
|
||||
@@ -155,6 +156,7 @@ class UIType(str, Enum):
|
||||
VaeModel = "VaeModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
UNet = "UNetField"
|
||||
Vae = "VaeField"
|
||||
CLIP = "ClipField"
|
||||
@@ -417,12 +419,27 @@ class UIConfigBase(BaseModel):
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Initialized and provided to on execution of invocations."""
|
||||
|
||||
services: InvocationServices
|
||||
graph_execution_state_id: str
|
||||
queue_id: str
|
||||
queue_item_id: int
|
||||
queue_batch_id: str
|
||||
|
||||
def __init__(self, services: InvocationServices, graph_execution_state_id: str):
|
||||
def __init__(
|
||||
self,
|
||||
services: InvocationServices,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
):
|
||||
self.services = services
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.queue_id = queue_id
|
||||
self.queue_item_id = queue_item_id
|
||||
self.queue_batch_id = queue_batch_id
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
|
||||
@@ -520,6 +537,9 @@ class BaseInvocation(ABC, BaseModel):
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
validate_all = True
|
||||
|
||||
@staticmethod
|
||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
uiconfig = getattr(model_class, "UIConfig", None)
|
||||
@@ -568,7 +588,29 @@ class BaseInvocation(ABC, BaseModel):
|
||||
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
|
||||
elif _input == Input.Any:
|
||||
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||
return self.invoke(context)
|
||||
|
||||
# skip node cache codepath if it's disabled
|
||||
if context.services.configuration.node_cache_size == 0:
|
||||
return self.invoke(context)
|
||||
|
||||
output: BaseInvocationOutput
|
||||
if self.use_cache:
|
||||
key = context.services.invocation_cache.create_key(self)
|
||||
cached_value = context.services.invocation_cache.get(key)
|
||||
if cached_value is None:
|
||||
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
||||
output = self.invoke(context)
|
||||
context.services.invocation_cache.save(key, output)
|
||||
return output
|
||||
else:
|
||||
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
|
||||
return cached_value
|
||||
else:
|
||||
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||
return self.invoke(context)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return self.__fields__["type"].default
|
||||
|
||||
id: str = Field(
|
||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
|
||||
@@ -581,6 +623,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
description="The workflow to save with the image",
|
||||
ui_type=UIType.WorkflowField,
|
||||
)
|
||||
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
|
||||
|
||||
@validator("workflow", pre=True)
|
||||
def validate_workflow_is_json(cls, v):
|
||||
@@ -604,6 +647,7 @@ def invocation(
|
||||
tags: Optional[list[str]] = None,
|
||||
category: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
use_cache: Optional[bool] = True,
|
||||
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
||||
"""
|
||||
Adds metadata to an invocation.
|
||||
@@ -636,6 +680,8 @@ def invocation(
|
||||
except ValueError as e:
|
||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||
cls.UIConfig.version = version
|
||||
if use_cache is not None:
|
||||
cls.__fields__["use_cache"].default = use_cache
|
||||
|
||||
# Add the invocation type to the pydantic model of the invocation
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
|
||||
@@ -56,6 +56,7 @@ class RangeOfSizeInvocation(BaseInvocation):
|
||||
tags=["range", "integer", "random", "collection"],
|
||||
category="collections",
|
||||
version="1.0.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
"""Creates a collection of random numbers"""
|
||||
|
||||
@@ -7,14 +7,14 @@ from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import ModelNotFoundException, ModelType
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -99,14 +99,15 @@ class CompelInvocation(BaseInvocation):
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder_info.context.model, _lora_loader()
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
), ModelPatcher.apply_clip_skip(
|
||||
text_encoder_info.context.model, self.clip.skipped_layers
|
||||
), text_encoder_info as text_encoder:
|
||||
with (
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
|
||||
text_encoder_info as text_encoder,
|
||||
):
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@@ -122,7 +123,7 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
ec = ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
@@ -213,14 +214,15 @@ class SDXLPromptInvocationBase:
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora(
|
||||
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
), ModelPatcher.apply_clip_skip(
|
||||
text_encoder_info.context.model, clip_field.skipped_layers
|
||||
), text_encoder_info as text_encoder:
|
||||
with (
|
||||
ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
|
||||
text_encoder_info as text_encoder,
|
||||
):
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@@ -244,7 +246,7 @@ class SDXLPromptInvocationBase:
|
||||
else:
|
||||
c_pooled = None
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
ec = ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
@@ -436,9 +438,11 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
|
||||
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
|
||||
|
||||
text_fragments = [
|
||||
x.text
|
||||
if type(x) is Fragment
|
||||
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
|
||||
(
|
||||
x.text
|
||||
if type(x) is Fragment
|
||||
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
|
||||
)
|
||||
for x in parsed_prompt.children
|
||||
]
|
||||
text = " ".join(text_fragments)
|
||||
|
||||
@@ -38,7 +38,6 @@ from .baseinvocation import (
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -100,7 +99,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||
default=1.0, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
|
||||
|
||||
@@ -965,3 +965,42 @@ class ImageChannelMultiplyInvocation(BaseInvocation):
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"save_image",
|
||||
title="Save Image",
|
||||
tags=["primitives", "image"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class SaveImageInvocation(BaseInvocation):
|
||||
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to load")
|
||||
metadata: CoreMetadata = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
103
invokeai/app/invocations/ip_adapter.py
Normal file
103
invokeai/app/invocations/ip_adapter.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import os
|
||||
from builtins import float
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
||||
|
||||
|
||||
class IPAdapterModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the IP-Adapter model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
class CLIPVisionModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
|
||||
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
|
||||
|
||||
|
||||
class IPAdapterField(BaseModel):
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
||||
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
|
||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
# weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("ip_adapter_output")
|
||||
class IPAdapterOutput(BaseInvocationOutput):
|
||||
# Outputs
|
||||
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.0.0")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: IPAdapterModelField = InputField(
|
||||
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
||||
)
|
||||
|
||||
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
|
||||
weight: Union[float, List[float]] = InputField(
|
||||
default=1, ge=0, description="The weight given to the IP-Adapter", ui_type=UIType.Float, title="Weight"
|
||||
)
|
||||
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=-1, le=2, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.services.model_manager.model_info(
|
||||
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
|
||||
)
|
||||
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
|
||||
# directly, and 2) we are reading from disk every time this invocation is called without caching the result.
|
||||
# A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this
|
||||
# is currently messy due to differences between how the model info is generated when installing a model from
|
||||
# disk vs. downloading the model.
|
||||
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
|
||||
os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"])
|
||||
)
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_model = CLIPVisionModelField(
|
||||
model_name=image_encoder_model_name,
|
||||
base_model=BaseModelType.Any,
|
||||
)
|
||||
return IPAdapterOutput(
|
||||
ip_adapter=IPAdapterField(
|
||||
image=self.image,
|
||||
ip_adapter_model=self.ip_adapter_model,
|
||||
image_encoder_model=image_encoder_model,
|
||||
weight=self.weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
),
|
||||
)
|
||||
@@ -1,13 +1,16 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
@@ -19,6 +22,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import (
|
||||
DenoiseMaskField,
|
||||
@@ -31,15 +35,17 @@ from invokeai.app.invocations.primitives import (
|
||||
)
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_management.seamless import set_seamless
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ConditioningData,
|
||||
ControlNetData,
|
||||
IPAdapterData,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
image_resized_to_grid_as_tensor,
|
||||
)
|
||||
@@ -68,7 +74,6 @@ if choose_torch_device() == torch.device("mps"):
|
||||
|
||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||
|
||||
|
||||
@@ -191,7 +196,7 @@ def get_scheduler(
|
||||
title="Denoise Latents",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
)
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
@@ -205,7 +210,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: Union[float, List[float]] = InputField(
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float, title="CFG Scale"
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
||||
)
|
||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
@@ -215,13 +220,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
input=Input.Connection,
|
||||
ui_order=5,
|
||||
)
|
||||
ip_adapter: Optional[IPAdapterField] = InputField(
|
||||
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
|
||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
|
||||
)
|
||||
|
||||
@validator("cfg_scale")
|
||||
@@ -323,8 +330,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def prep_control_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
# really only need model for dtype and device
|
||||
model: StableDiffusionGeneratorPipeline,
|
||||
control_input: Union[ControlField, List[ControlField]],
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
@@ -344,57 +349,107 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
else:
|
||||
control_list = None
|
||||
if control_list is None:
|
||||
control_data = None
|
||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||
else:
|
||||
# FIXME: add checks to skip entry if model or image is None
|
||||
# and if weight is None, populate with default 1.0?
|
||||
control_data = []
|
||||
control_models = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
return None
|
||||
# After above handling, any control that is not None should now be of type list[ControlField].
|
||||
|
||||
control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
# and do real check for classifier_free_guidance?
|
||||
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||
control_image = prepare_control_image(
|
||||
image=input_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=control_width_resize,
|
||||
height=control_height_resize,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
control_mode=control_info.control_mode,
|
||||
resize_mode=control_info.resize_mode,
|
||||
# FIXME: add checks to skip entry if model or image is None
|
||||
# and if weight is None, populate with default 1.0?
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
control_item = ControlNetData(
|
||||
model=control_model,
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,
|
||||
# any resizing needed should currently be happening in prepare_control_image(),
|
||||
# but adding resize_mode to ControlNetData in case needed in the future
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
control_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
return control_data
|
||||
)
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
# and do real check for classifier_free_guidance?
|
||||
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||
control_image = prepare_control_image(
|
||||
image=input_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=control_width_resize,
|
||||
height=control_height_resize,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
control_mode=control_info.control_mode,
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
control_item = ControlNetData(
|
||||
model=control_model, # model object
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,
|
||||
# any resizing needed should currently be happening in prepare_control_image(),
|
||||
# but adding resize_mode to ControlNetData in case needed in the future
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
controlnet_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
|
||||
return controlnet_data
|
||||
|
||||
def prep_ip_adapter_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
ip_adapter: Optional[IPAdapterField],
|
||||
conditioning_data: ConditioningData,
|
||||
unet: UNet2DConditionModel,
|
||||
exit_stack: ExitStack,
|
||||
) -> Optional[IPAdapterData]:
|
||||
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
|
||||
to the `conditioning_data` (in-place).
|
||||
"""
|
||||
if ip_adapter is None:
|
||||
return None
|
||||
|
||||
image_encoder_model_info = context.services.model_manager.get_model(
|
||||
model_name=ip_adapter.image_encoder_model.model_name,
|
||||
model_type=ModelType.CLIPVision,
|
||||
base_model=ip_adapter.image_encoder_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=ip_adapter.ip_adapter_model.model_name,
|
||||
model_type=ModelType.IPAdapter,
|
||||
base_model=ip_adapter.ip_adapter_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
|
||||
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
input_image, image_encoder_model
|
||||
)
|
||||
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
|
||||
image_prompt_embeds, uncond_image_prompt_embeds
|
||||
)
|
||||
|
||||
return IPAdapterData(
|
||||
ip_adapter_model=ip_adapter_model,
|
||||
weight=ip_adapter.weight,
|
||||
begin_step_percent=ip_adapter.begin_step_percent,
|
||||
end_step_percent=ip_adapter.end_step_percent,
|
||||
)
|
||||
|
||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||
# TODO: research more for second order schedulers timesteps
|
||||
@@ -488,9 +543,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
**self.unet.unet.dict(),
|
||||
context=context,
|
||||
)
|
||||
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
||||
unet_info.context.model, _lora_loader()
|
||||
), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet:
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
|
||||
set_seamless(unet_info.context.model, self.unet.seamless_axes),
|
||||
unet_info as unet,
|
||||
):
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
@@ -509,8 +567,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
|
||||
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline,
|
||||
controlnet_data = self.prep_control_data(
|
||||
context=context,
|
||||
control_input=self.control,
|
||||
latents_shape=latents.shape,
|
||||
@@ -519,6 +576,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
ip_adapter_data = self.prep_ip_adapter_data(
|
||||
context=context,
|
||||
ip_adapter=self.ip_adapter,
|
||||
conditioning_data=conditioning_data,
|
||||
unet=unet,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
@@ -537,7 +602,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
masked_latents=masked_latents,
|
||||
num_inference_steps=num_inference_steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
control_data=controlnet_data, # list[ControlNetData],
|
||||
ip_adapter_data=ip_adapter_data, # IPAdapterData,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
@@ -792,8 +858,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
# non_noised_latents_from_image
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
|
||||
|
||||
latents = vae.config.scaling_factor * latents
|
||||
latents = latents.to(dtype=orig_dtype)
|
||||
@@ -820,6 +885,18 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
@singledispatchmethod
|
||||
@staticmethod
|
||||
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
return latents
|
||||
|
||||
@_encode_to_tensor.register
|
||||
@staticmethod
|
||||
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return vae.encode(image_tensor).latents
|
||||
|
||||
|
||||
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
|
||||
@@ -54,7 +54,14 @@ class DivideInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=int(self.a / self.b))
|
||||
|
||||
|
||||
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
|
||||
@invocation(
|
||||
"rand_int",
|
||||
title="Random Integer",
|
||||
tags=["math", "random"],
|
||||
category="math",
|
||||
version="1.0.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
|
||||
|
||||
@@ -42,7 +42,8 @@ class CoreMetadata(BaseModelExcludeNull):
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(
|
||||
clip_skip: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
@@ -116,7 +117,8 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
|
||||
steps: int = InputField(description="The number of steps used for inference")
|
||||
scheduler: str = InputField(description="The scheduler used for inference")
|
||||
clip_skip: int = InputField(
|
||||
clip_skip: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = InputField(description="The main model used for inference")
|
||||
|
||||
@@ -95,9 +95,10 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
if loras or ti_list:
|
||||
text_encoder.release_session()
|
||||
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
|
||||
orig_tokenizer, text_encoder, ti_list
|
||||
) as (tokenizer, ti_manager):
|
||||
with (
|
||||
ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
|
||||
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager),
|
||||
):
|
||||
text_encoder.create_session()
|
||||
|
||||
# copy from
|
||||
@@ -165,7 +166,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
default=7.5,
|
||||
ge=1,
|
||||
description=FieldDescriptions.cfg_scale,
|
||||
ui_type=UIType.Float,
|
||||
)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
||||
@@ -178,7 +178,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
ui_type=UIType.Control,
|
||||
)
|
||||
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
|
||||
@@ -10,7 +10,14 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
||||
|
||||
|
||||
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
|
||||
@invocation(
|
||||
"dynamic_prompt",
|
||||
title="Dynamic Prompt",
|
||||
tags=["prompt", "collection"],
|
||||
category="prompt",
|
||||
version="1.0.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class DynamicPromptInvocation(BaseInvocation):
|
||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||
|
||||
|
||||
@@ -53,24 +53,20 @@ class BoardImageRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
self._lock = lock
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
@@ -8,6 +7,7 @@ from pydantic import BaseModel, Extra, Field
|
||||
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||
@@ -87,24 +87,20 @@ class BoardRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
self._lock = lock
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
@@ -174,7 +170,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
board_id = str(uuid.uuid4())
|
||||
board_id = uuid_string()
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
|
||||
@@ -16,7 +16,7 @@ import pydoc
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Dict, List, Literal, Union, get_args, get_origin, get_type_hints
|
||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from pydantic import BaseSettings
|
||||
@@ -39,10 +39,10 @@ class InvokeAISettings(BaseSettings):
|
||||
read from an omegaconf .yaml file.
|
||||
"""
|
||||
|
||||
initconf: ClassVar[DictConfig] = None
|
||||
initconf: ClassVar[Optional[DictConfig]] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list = sys.argv[1:]):
|
||||
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt, unknown_opts = parser.parse_known_args(argv)
|
||||
if len(unknown_opts) > 0:
|
||||
@@ -83,7 +83,8 @@ class InvokeAISettings(BaseSettings):
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
||||
env_prefix = getattr(cls.Config, "env_prefix", None)
|
||||
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
|
||||
|
||||
initconf = (
|
||||
cls.initconf.get(settings_stanza)
|
||||
@@ -116,8 +117,8 @@ class InvokeAISettings(BaseSettings):
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str = "type") -> str:
|
||||
hints = get_type_hints(self)
|
||||
def cmd_name(cls, command_field: str = "type") -> str:
|
||||
hints = get_type_hints(cls)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
@@ -133,16 +134,12 @@ class InvokeAISettings(BaseSettings):
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self) -> List[str]:
|
||||
def _excluded(cls) -> List[str]:
|
||||
# internal fields that shouldn't be exposed as command line options
|
||||
return ["type", "initconf"]
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(self) -> List[str]:
|
||||
def _excluded_from_yaml(cls) -> List[str]:
|
||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||
return [
|
||||
"type",
|
||||
|
||||
@@ -194,8 +194,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
setting environment variables INVOKEAI_<setting>.
|
||||
"""
|
||||
|
||||
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
||||
singleton_init: ClassVar[Dict] = None
|
||||
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
|
||||
singleton_init: ClassVar[Optional[Dict]] = None
|
||||
|
||||
# fmt: off
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
@@ -234,6 +234,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||
log_sql : bool = Field(default=False, description="Log SQL queries", category="Logging")
|
||||
|
||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", category="Development")
|
||||
|
||||
@@ -245,18 +246,23 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
||||
|
||||
# DEVICE
|
||||
device : Literal[tuple(["auto", "cpu", "cuda", "cuda:1", "mps"])] = Field(default="auto", description="Generation device", category="Device", )
|
||||
precision: Literal[tuple(["auto", "float16", "float32", "autocast"])] = Field(default="auto", description="Floating point precision", category="Device", )
|
||||
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", category="Device", )
|
||||
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", category="Device", )
|
||||
|
||||
# GENERATION
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
|
||||
attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", )
|
||||
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
|
||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||
|
||||
# QUEUE
|
||||
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
|
||||
|
||||
# NODES
|
||||
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
|
||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
|
||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", category="Nodes", )
|
||||
|
||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
@@ -272,7 +278,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
|
||||
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
|
||||
def parse_args(self, argv: Optional[list[str]] = None, conf: Optional[DictConfig] = None, clobber=False):
|
||||
"""
|
||||
Update settings with contents of init file, environment, and
|
||||
command-line settings.
|
||||
@@ -283,12 +289,16 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
# Set the runtime root directory. We parse command-line switches here
|
||||
# in order to pick up the --root_dir option.
|
||||
super().parse_args(argv)
|
||||
loaded_conf = None
|
||||
if conf is None:
|
||||
try:
|
||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||
loaded_conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||
except Exception:
|
||||
pass
|
||||
InvokeAISettings.initconf = conf
|
||||
if isinstance(loaded_conf, DictConfig):
|
||||
InvokeAISettings.initconf = loaded_conf
|
||||
else:
|
||||
InvokeAISettings.initconf = conf
|
||||
|
||||
# parse args again in order to pick up settings in configuration file
|
||||
super().parse_args(argv)
|
||||
@@ -376,13 +386,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def autoconvert_path(self) -> Path:
|
||||
"""
|
||||
Path to the directory containing models to be imported automatically at startup.
|
||||
"""
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self) -> bool:
|
||||
@@ -405,11 +408,11 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
return True
|
||||
|
||||
@property
|
||||
def ram_cache_size(self) -> float:
|
||||
def ram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||
return self.max_cache_size or self.ram
|
||||
|
||||
@property
|
||||
def vram_cache_size(self) -> float:
|
||||
def vram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||
return self.max_vram_cache_size or self.vram
|
||||
|
||||
@property
|
||||
|
||||
@@ -10,57 +10,58 @@ default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
|
||||
|
||||
|
||||
def create_text_to_image() -> LibraryGraph:
|
||||
graph = Graph(
|
||||
nodes={
|
||||
"width": IntegerInvocation(id="width", value=512),
|
||||
"height": IntegerInvocation(id="height", value=512),
|
||||
"seed": IntegerInvocation(id="seed", value=-1),
|
||||
"3": NoiseInvocation(id="3"),
|
||||
"4": CompelInvocation(id="4"),
|
||||
"5": CompelInvocation(id="5"),
|
||||
"6": DenoiseLatentsInvocation(id="6"),
|
||||
"7": LatentsToImageInvocation(id="7"),
|
||||
"8": ImageNSFWBlurInvocation(id="8"),
|
||||
},
|
||||
edges=[
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="width", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="width"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="height", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="height"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="seed", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="seed"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="3", field="noise"),
|
||||
destination=EdgeConnection(node_id="6", field="noise"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="6", field="latents"),
|
||||
destination=EdgeConnection(node_id="7", field="latents"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="4", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="5", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="7", field="image"),
|
||||
destination=EdgeConnection(node_id="8", field="image"),
|
||||
),
|
||||
],
|
||||
)
|
||||
return LibraryGraph(
|
||||
id=default_text_to_image_graph_id,
|
||||
name="t2i",
|
||||
description="Converts text to an image",
|
||||
graph=Graph(
|
||||
nodes={
|
||||
"width": IntegerInvocation(id="width", value=512),
|
||||
"height": IntegerInvocation(id="height", value=512),
|
||||
"seed": IntegerInvocation(id="seed", value=-1),
|
||||
"3": NoiseInvocation(id="3"),
|
||||
"4": CompelInvocation(id="4"),
|
||||
"5": CompelInvocation(id="5"),
|
||||
"6": DenoiseLatentsInvocation(id="6"),
|
||||
"7": LatentsToImageInvocation(id="7"),
|
||||
"8": ImageNSFWBlurInvocation(id="8"),
|
||||
},
|
||||
edges=[
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="width", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="width"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="height", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="height"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="seed", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="seed"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="3", field="noise"),
|
||||
destination=EdgeConnection(node_id="6", field="noise"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="6", field="latents"),
|
||||
destination=EdgeConnection(node_id="7", field="latents"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="4", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="5", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="7", field="image"),
|
||||
destination=EdgeConnection(node_id="8", field="image"),
|
||||
),
|
||||
],
|
||||
),
|
||||
graph=graph,
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
||||
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
||||
|
||||
@@ -4,21 +4,23 @@ from typing import Any, Optional
|
||||
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
|
||||
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
queue_event: str = "queue_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
pass
|
||||
|
||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Queue events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.session_event,
|
||||
event_name=EventServiceBase.queue_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
)
|
||||
|
||||
@@ -26,6 +28,9 @@ class EventServiceBase:
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
@@ -35,11 +40,14 @@ class EventServiceBase:
|
||||
total_steps: int,
|
||||
) -> None:
|
||||
"""Emitted when there is generation progress"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="generator_progress",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
node_id=node.get("id"),
|
||||
source_node_id=source_node_id,
|
||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||
step=step,
|
||||
@@ -50,15 +58,21 @@ class EventServiceBase:
|
||||
|
||||
def emit_invocation_complete(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
result: dict,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_complete",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
@@ -68,6 +82,9 @@ class EventServiceBase:
|
||||
|
||||
def emit_invocation_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
@@ -75,9 +92,12 @@ class EventServiceBase:
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
@@ -86,28 +106,47 @@ class EventServiceBase:
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
|
||||
def emit_invocation_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_started",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
|
||||
def emit_graph_execution_complete(
|
||||
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
|
||||
) -> None:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="graph_execution_state_complete",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_model_load_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
@@ -115,9 +154,12 @@ class EventServiceBase:
|
||||
submodel: SubModelType,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_started",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
@@ -128,6 +170,9 @@ class EventServiceBase:
|
||||
|
||||
def emit_model_load_completed(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
@@ -136,9 +181,12 @@ class EventServiceBase:
|
||||
model_info: ModelInfo,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_completed",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
@@ -152,14 +200,20 @@ class EventServiceBase:
|
||||
|
||||
def emit_session_retrieval_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when session retrieval fails"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="session_retrieval_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
@@ -168,18 +222,78 @@ class EventServiceBase:
|
||||
|
||||
def emit_invocation_retrieval_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when invocation retrieval fails"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_retrieval_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node_id=node_id,
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_session_canceled(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
) -> None:
|
||||
"""Emitted when a session is canceled"""
|
||||
self.__emit_queue_event(
|
||||
event_name="session_canceled",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_queue_item_status_changed(self, session_queue_item: SessionQueueItem) -> None:
|
||||
"""Emitted when a queue item's status changes"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_item_status_changed",
|
||||
payload=dict(
|
||||
queue_id=session_queue_item.queue_id,
|
||||
queue_item_id=session_queue_item.item_id,
|
||||
status=session_queue_item.status,
|
||||
batch_id=session_queue_item.batch_id,
|
||||
session_id=session_queue_item.session_id,
|
||||
error=session_queue_item.error,
|
||||
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
||||
"""Emitted when a batch is enqueued"""
|
||||
self.__emit_queue_event(
|
||||
event_name="batch_enqueued",
|
||||
payload=dict(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||
"""Emitted when the queue is cleared"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_cleared",
|
||||
payload=dict(queue_id=queue_id),
|
||||
)
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import uuid
|
||||
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Annotated, Any, Optional, Union, cast, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from pydantic.fields import Field
|
||||
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ..invocations import * # noqa: F401 F403
|
||||
from ..invocations.baseinvocation import (
|
||||
@@ -137,19 +138,31 @@ def are_connections_compatible(
|
||||
return are_connection_types_compatible(from_node_field, to_node_field)
|
||||
|
||||
|
||||
class NodeAlreadyInGraphError(Exception):
|
||||
class NodeAlreadyInGraphError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidEdgeError(Exception):
|
||||
class InvalidEdgeError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class NodeNotFoundError(Exception):
|
||||
class NodeNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class NodeAlreadyExecutedError(Exception):
|
||||
class NodeAlreadyExecutedError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DuplicateNodeIdError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class NodeFieldNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class NodeIdMismatchError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
@@ -227,7 +240,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
@@ -237,6 +250,59 @@ class Graph(BaseModel):
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
@root_validator
|
||||
def validate_nodes_and_edges(cls, values):
|
||||
"""Validates that all edges match nodes in the graph"""
|
||||
nodes = cast(Optional[dict[str, BaseInvocation]], values.get("nodes"))
|
||||
edges = cast(Optional[list[Edge]], values.get("edges"))
|
||||
|
||||
if nodes is not None:
|
||||
# Validate that all node ids are unique
|
||||
node_ids = [n.id for n in nodes.values()]
|
||||
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
|
||||
if duplicate_node_ids:
|
||||
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
||||
|
||||
# Validate that all node ids match the keys in the nodes dict
|
||||
for k, v in nodes.items():
|
||||
if k != v.id:
|
||||
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
|
||||
|
||||
if edges is not None and nodes is not None:
|
||||
# Validate that all edges match nodes in the graph
|
||||
node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges])
|
||||
missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes]
|
||||
if missing_node_ids:
|
||||
raise NodeNotFoundError(
|
||||
f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}"
|
||||
)
|
||||
|
||||
# Validate that all edge fields match node fields in the graph
|
||||
for edge in edges:
|
||||
source_node = nodes.get(edge.source.node_id, None)
|
||||
if source_node is None:
|
||||
raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
|
||||
|
||||
destination_node = nodes.get(edge.destination.node_id, None)
|
||||
if destination_node is None:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge destination node {edge.destination.node_id} does not exist in the graph"
|
||||
)
|
||||
|
||||
# output fields are not on the node object directly, they are on the output type
|
||||
if edge.source.field not in source_node.get_output_type().__fields__:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
||||
)
|
||||
|
||||
# input fields are on the node
|
||||
if edge.destination.field not in destination_node.__fields__:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
"""Adds a node to a graph
|
||||
|
||||
@@ -697,8 +763,7 @@ class Graph(BaseModel):
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
||||
|
||||
id: str = Field(description="The id of the execution state", default_factory=uuid_string)
|
||||
# TODO: Store a reference to the graph instead of the actual graph?
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
|
||||
@@ -847,7 +912,7 @@ class GraphExecutionState(BaseModel):
|
||||
new_node = copy.deepcopy(node)
|
||||
|
||||
# Create the node id (use a random uuid)
|
||||
new_node.id = str(uuid.uuid4())
|
||||
new_node.id = uuid_string()
|
||||
|
||||
# Set the iteration index for iteration invocations
|
||||
if isinstance(new_node, IterateInvocation):
|
||||
@@ -1082,7 +1147,7 @@ class ExposedNodeOutput(BaseModel):
|
||||
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid_string)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
|
||||
@@ -148,24 +148,20 @@ class ImageRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
self._lock = lock
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
@@ -38,6 +38,29 @@ if TYPE_CHECKING:
|
||||
class ImageServiceABC(ABC):
|
||||
"""High-level service for image management."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
|
||||
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
||||
"""Register a callback for when an image is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an image is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: ImageDTO) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
@@ -161,6 +184,7 @@ class ImageService(ImageServiceABC):
|
||||
_services: ImageServiceDependencies
|
||||
|
||||
def __init__(self, services: ImageServiceDependencies):
|
||||
super().__init__()
|
||||
self._services = services
|
||||
|
||||
def create(
|
||||
@@ -217,6 +241,7 @@ class ImageService(ImageServiceABC):
|
||||
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
self._on_changed(image_dto)
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
@@ -235,7 +260,9 @@ class ImageService(ImageServiceABC):
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.image_records.update(image_name, changes)
|
||||
return self.get_dto(image_name)
|
||||
image_dto = self.get_dto(image_name)
|
||||
self._on_changed(image_dto)
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
raise
|
||||
@@ -374,6 +401,7 @@ class ImageService(ImageServiceABC):
|
||||
try:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete(image_name)
|
||||
self._on_deleted(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error("Failed to delete image record")
|
||||
raise
|
||||
@@ -390,6 +418,8 @@ class ImageService(ImageServiceABC):
|
||||
for image_name in image_names:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete_many(image_names)
|
||||
for image_name in image_names:
|
||||
self._on_deleted(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error("Failed to delete image records")
|
||||
raise
|
||||
@@ -406,6 +436,7 @@ class ImageService(ImageServiceABC):
|
||||
count = len(image_names)
|
||||
for image_name in image_names:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._on_deleted(image_name)
|
||||
return count
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error("Failed to delete image records")
|
||||
|
||||
0
invokeai/app/services/invocation_cache/__init__.py
Normal file
0
invokeai/app/services/invocation_cache/__init__.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
|
||||
|
||||
class InvocationCacheBase(ABC):
|
||||
"""
|
||||
Base class for invocation caches.
|
||||
When an invocation is executed, it is hashed and its output stored in the cache.
|
||||
When new invocations are executed, if they are flagged with `use_cache`, they
|
||||
will attempt to pull their value from the cache before executing.
|
||||
|
||||
Implementations should register for the `on_deleted` event of the `images` and `latents`
|
||||
services, and delete any cached outputs that reference the deleted image or latent.
|
||||
|
||||
See the memory implementation for an example.
|
||||
|
||||
Implementations should respect the `node_cache_size` configuration value, and skip all
|
||||
cache logic if the value is set to 0.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||
"""Retrieves an invocation output from the cache"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
|
||||
"""Stores an invocation output in the cache"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: Union[int, str]) -> None:
|
||||
"""Deletes an invocation output from the cache"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Clears the cache"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_key(self, invocation: BaseInvocation) -> int:
|
||||
"""Gets the key for the invocation's cache item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def disable(self) -> None:
|
||||
"""Disables the cache, overriding the max cache size"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enable(self) -> None:
|
||||
"""Enables the cache, letting the the max cache size take effect"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_status(self) -> InvocationCacheStatus:
|
||||
"""Returns the status of the cache"""
|
||||
pass
|
||||
@@ -0,0 +1,9 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InvocationCacheStatus(BaseModel):
|
||||
size: int = Field(description="The current size of the invocation cache")
|
||||
hits: int = Field(description="The number of cache hits")
|
||||
misses: int = Field(description="The number of cache misses")
|
||||
enabled: bool = Field(description="Whether the invocation cache is enabled")
|
||||
max_size: int = Field(description="The maximum size of the invocation cache")
|
||||
@@ -0,0 +1,111 @@
|
||||
from queue import Queue
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
class MemoryInvocationCache(InvocationCacheBase):
|
||||
__cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
|
||||
__max_cache_size: int
|
||||
__disabled: bool
|
||||
__hits: int
|
||||
__misses: int
|
||||
__cache_ids: Queue
|
||||
__invoker: Invoker
|
||||
|
||||
def __init__(self, max_cache_size: int = 0) -> None:
|
||||
self.__cache = dict()
|
||||
self.__max_cache_size = max_cache_size
|
||||
self.__disabled = False
|
||||
self.__hits = 0
|
||||
self.__misses = 0
|
||||
self.__cache_ids = Queue()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
if self.__max_cache_size == 0:
|
||||
return
|
||||
self.__invoker.services.images.on_deleted(self._delete_by_match)
|
||||
self.__invoker.services.latents.on_deleted(self._delete_by_match)
|
||||
|
||||
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||
if self.__max_cache_size == 0 or self.__disabled:
|
||||
return
|
||||
|
||||
item = self.__cache.get(key, None)
|
||||
if item is not None:
|
||||
self.__hits += 1
|
||||
return item[0]
|
||||
self.__misses += 1
|
||||
|
||||
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
|
||||
if self.__max_cache_size == 0 or self.__disabled:
|
||||
return
|
||||
|
||||
if key not in self.__cache:
|
||||
self.__cache[key] = (invocation_output, invocation_output.json())
|
||||
self.__cache_ids.put(key)
|
||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||
try:
|
||||
self.__cache.pop(self.__cache_ids.get())
|
||||
except KeyError:
|
||||
# this means the cache_ids are somehow out of sync w/ the cache
|
||||
pass
|
||||
|
||||
def delete(self, key: Union[int, str]) -> None:
|
||||
if self.__max_cache_size == 0 or self.__disabled:
|
||||
return
|
||||
|
||||
if key in self.__cache:
|
||||
del self.__cache[key]
|
||||
|
||||
def clear(self, *args, **kwargs) -> None:
|
||||
if self.__max_cache_size == 0 or self.__disabled:
|
||||
return
|
||||
|
||||
self.__cache.clear()
|
||||
self.__cache_ids = Queue()
|
||||
self.__misses = 0
|
||||
self.__hits = 0
|
||||
|
||||
def create_key(self, invocation: BaseInvocation) -> int:
|
||||
return hash(invocation.json(exclude={"id"}))
|
||||
|
||||
def disable(self) -> None:
|
||||
if self.__max_cache_size == 0:
|
||||
return
|
||||
self.__disabled = True
|
||||
|
||||
def enable(self) -> None:
|
||||
if self.__max_cache_size == 0:
|
||||
return
|
||||
self.__disabled = False
|
||||
|
||||
def get_status(self) -> InvocationCacheStatus:
|
||||
return InvocationCacheStatus(
|
||||
hits=self.__hits,
|
||||
misses=self.__misses,
|
||||
enabled=not self.__disabled and self.__max_cache_size > 0,
|
||||
size=len(self.__cache),
|
||||
max_size=self.__max_cache_size,
|
||||
)
|
||||
|
||||
def _delete_by_match(self, to_match: str) -> None:
|
||||
if self.__max_cache_size == 0 or self.__disabled:
|
||||
return
|
||||
|
||||
keys_to_delete = set()
|
||||
for key, value_tuple in self.__cache.items():
|
||||
if to_match in value_tuple[1]:
|
||||
keys_to_delete.add(key)
|
||||
|
||||
if not keys_to_delete:
|
||||
return
|
||||
|
||||
for key in keys_to_delete:
|
||||
self.delete(key)
|
||||
|
||||
self.__invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")
|
||||
@@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
|
||||
session_queue_item_id: int = Field(
|
||||
description="The ID of session queue item from which this invocation queue item came"
|
||||
)
|
||||
session_queue_batch_id: str = Field(
|
||||
description="The ID of the session batch from which this invocation queue item came"
|
||||
)
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
@@ -12,12 +12,15 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
|
||||
|
||||
class InvocationServices:
|
||||
@@ -28,8 +31,8 @@ class InvocationServices:
|
||||
boards: "BoardServiceABC"
|
||||
configuration: "InvokeAIAppConfig"
|
||||
events: "EventServiceBase"
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
||||
graph_library: "ItemStorageABC[LibraryGraph]"
|
||||
images: "ImageServiceABC"
|
||||
latents: "LatentsStorageBase"
|
||||
logger: "Logger"
|
||||
@@ -37,6 +40,9 @@ class InvocationServices:
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
session_queue: "SessionQueueBase"
|
||||
session_processor: "SessionProcessorBase"
|
||||
invocation_cache: "InvocationCacheBase"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,8 +50,8 @@ class InvocationServices:
|
||||
boards: "BoardServiceABC",
|
||||
configuration: "InvokeAIAppConfig",
|
||||
events: "EventServiceBase",
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
||||
graph_library: "ItemStorageABC[LibraryGraph]",
|
||||
images: "ImageServiceABC",
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
@@ -53,10 +59,12 @@ class InvocationServices:
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
session_queue: "SessionQueueBase",
|
||||
session_processor: "SessionProcessorBase",
|
||||
invocation_cache: "InvocationCacheBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.boards = boards
|
||||
self.boards = boards
|
||||
self.configuration = configuration
|
||||
self.events = events
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
@@ -68,3 +76,6 @@ class InvocationServices:
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
self.session_queue = session_queue
|
||||
self.session_processor = session_processor
|
||||
self.invocation_cache = invocation_cache
|
||||
|
||||
@@ -17,7 +17,14 @@ class Invoker:
|
||||
self.services = services
|
||||
self._start()
|
||||
|
||||
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
|
||||
def invoke(
|
||||
self,
|
||||
session_queue_id: str,
|
||||
session_queue_item_id: int,
|
||||
session_queue_batch_id: str,
|
||||
graph_execution_state: GraphExecutionState,
|
||||
invoke_all: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
@@ -32,7 +39,9 @@ class Invoker:
|
||||
# Queue the invocation
|
||||
self.services.queue.put(
|
||||
InvocationQueueItem(
|
||||
# session_id = session.id,
|
||||
session_queue_id=session_queue_id,
|
||||
session_queue_item_id=session_queue_item_id,
|
||||
session_queue_batch_id=session_queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_id=invocation.id,
|
||||
invoke_all=invoke_all,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -11,6 +11,13 @@ import torch
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
pass
|
||||
@@ -23,6 +30,22 @@ class LatentsStorageBase(ABC):
|
||||
def delete(self, name: str) -> None:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
|
||||
"""Register a callback for when an item is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an item is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: torch.Tensor) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
||||
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
@@ -33,6 +56,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
__underlying_storage: LatentsStorageBase
|
||||
|
||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self.__underlying_storage = underlying_storage
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
@@ -50,11 +74,13 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
self._on_changed(data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self.__underlying_storage.delete(name)
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
@@ -525,7 +525,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context,
|
||||
context: InvocationContext,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
@@ -537,6 +537,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
@@ -546,6 +549,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import time
|
||||
import traceback
|
||||
from threading import BoundedSemaphore, Event, Thread
|
||||
from typing import Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
@@ -37,10 +38,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
||||
queue_item: Optional[InvocationQueueItem] = None
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
queue_item = self.__invoker.services.queue.get()
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
|
||||
|
||||
@@ -48,7 +50,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
try:
|
||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
||||
queue_item.graph_execution_state_id
|
||||
@@ -56,6 +57,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
||||
self.__invoker.services.events.emit_session_retrieval_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
||||
@@ -67,6 +71,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_retrieval_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
node_id=queue_item.invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
@@ -79,6 +86,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
# Send starting event
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -89,13 +99,17 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
graph_id = graph_execution_state.id
|
||||
model_manager = self.__invoker.services.model_manager
|
||||
with statistics.collect_stats(invocation, graph_id, model_manager):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method in
|
||||
# this accomodates nodes which require a value, but get it only from a
|
||||
# connection
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||
# which handles a few things:
|
||||
# - nodes that require a value, but get it only from a connection
|
||||
# - referencing the invocation cache instead of executing the node
|
||||
outputs = invocation.invoke_internal(
|
||||
InvocationContext(
|
||||
services=self.__invoker.services,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -111,6 +125,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -138,6 +155,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -155,10 +175,19 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
try:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
self.__invoker.invoke(
|
||||
session_queue_batch_id=queue_item.session_queue_batch_id,
|
||||
session_queue_item_id=queue_item.session_queue_item_id,
|
||||
session_queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state=graph_execution_state,
|
||||
invoke_all=True,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -166,7 +195,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, EnumMeta
|
||||
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||
"""Enum for resource types."""
|
||||
@@ -25,6 +26,6 @@ class SimpleNameService(NameServiceBase):
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
def create_image_name(self) -> str:
|
||||
uuid_str = str(uuid.uuid4())
|
||||
uuid_str = uuid_string()
|
||||
filename = f"{uuid_str}.png"
|
||||
return filename
|
||||
|
||||
0
invokeai/app/services/session_processor/__init__.py
Normal file
0
invokeai/app/services/session_processor/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
|
||||
|
||||
class SessionProcessorBase(ABC):
|
||||
"""
|
||||
Base class for session processor.
|
||||
|
||||
The session processor is responsible for executing sessions. It runs a simple polling loop,
|
||||
checking the session queue for new sessions to execute. It must coordinate with the
|
||||
invocation queue to ensure only one session is executing at a time.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
"""Starts or resumes the session processor"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause(self) -> SessionProcessorStatus:
|
||||
"""Pauses the session processor"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
"""Gets the status of the session processor"""
|
||||
pass
|
||||
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SessionProcessorStatus(BaseModel):
|
||||
is_started: bool = Field(description="Whether the session processor is started")
|
||||
is_processing: bool = Field(description="Whether a session is being processed")
|
||||
@@ -0,0 +1,124 @@
|
||||
from threading import BoundedSemaphore
|
||||
from threading import Event as ThreadEvent
|
||||
from threading import Thread
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import SessionProcessorBase
|
||||
from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
POLLING_INTERVAL = 1
|
||||
THREAD_LIMIT = 1
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker: Invoker = invoker
|
||||
self.__queue_item: Optional[SessionQueueItem] = None
|
||||
|
||||
self.__resume_event = ThreadEvent()
|
||||
self.__stop_event = ThreadEvent()
|
||||
self.__poll_now_event = ThreadEvent()
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||
|
||||
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
|
||||
self.__thread = Thread(
|
||||
name="session_processor",
|
||||
target=self.__process,
|
||||
kwargs=dict(
|
||||
stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
|
||||
),
|
||||
)
|
||||
self.__thread.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self.__stop_event.set()
|
||||
|
||||
def _poll_now(self) -> None:
|
||||
self.__poll_now_event.set()
|
||||
|
||||
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
match event_name:
|
||||
case "graph_execution_state_complete" | "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
case "session_canceled" if self.__queue_item is not None and self.__queue_item.session_id == event[1][
|
||||
"data"
|
||||
]["graph_execution_state_id"]:
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
case "batch_enqueued":
|
||||
self._poll_now()
|
||||
case "queue_cleared":
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
if not self.__resume_event.is_set():
|
||||
self.__resume_event.set()
|
||||
return self.get_status()
|
||||
|
||||
def pause(self) -> SessionProcessorStatus:
|
||||
if self.__resume_event.is_set():
|
||||
self.__resume_event.clear()
|
||||
return self.get_status()
|
||||
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
return SessionProcessorStatus(
|
||||
is_started=self.__resume_event.is_set(),
|
||||
is_processing=self.__queue_item is not None,
|
||||
)
|
||||
|
||||
def __process(
|
||||
self,
|
||||
stop_event: ThreadEvent,
|
||||
poll_now_event: ThreadEvent,
|
||||
resume_event: ThreadEvent,
|
||||
):
|
||||
try:
|
||||
stop_event.clear()
|
||||
resume_event.set()
|
||||
self.__threadLimit.acquire()
|
||||
queue_item: Optional[SessionQueueItem] = None
|
||||
self.__invoker.services.logger
|
||||
while not stop_event.is_set():
|
||||
poll_now_event.clear()
|
||||
|
||||
# do not dequeue if there is already a session running
|
||||
if self.__queue_item is None and resume_event.is_set():
|
||||
queue_item = self.__invoker.services.session_queue.dequeue()
|
||||
|
||||
if queue_item is not None:
|
||||
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
|
||||
self.__queue_item = queue_item
|
||||
self.__invoker.services.graph_execution_manager.set(queue_item.session)
|
||||
self.__invoker.invoke(
|
||||
session_queue_batch_id=queue_item.batch_id,
|
||||
session_queue_id=queue_item.queue_id,
|
||||
session_queue_item_id=queue_item.item_id,
|
||||
graph_execution_state=queue_item.session,
|
||||
invoke_all=True,
|
||||
)
|
||||
queue_item = None
|
||||
|
||||
if queue_item is None:
|
||||
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(POLLING_INTERVAL)
|
||||
continue
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error(f"Error in session processor: {e}")
|
||||
pass
|
||||
finally:
|
||||
stop_event.clear()
|
||||
poll_now_event.clear()
|
||||
self.__queue_item = None
|
||||
self.__threadLimit.release()
|
||||
0
invokeai/app/services/session_queue/__init__.py
Normal file
0
invokeai/app/services/session_queue/__init__.py
Normal file
112
invokeai/app/services/session_queue/session_queue_base.py
Normal file
112
invokeai/app/services/session_queue/session_queue_base.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.graph import Graph
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
EnqueueGraphResult,
|
||||
IsEmptyResult,
|
||||
IsFullResult,
|
||||
PruneResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||
|
||||
|
||||
class SessionQueueBase(ABC):
|
||||
"""Base class for session queue"""
|
||||
|
||||
@abstractmethod
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
"""Dequeues the next session queue item."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
|
||||
"""Enqueues a single graph for execution."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
"""Enqueues all permutations of a batch for execution."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently-executing session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
"""Gets the next session queue item (does not dequeue it)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
"""Deletes all session queue items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
"""Deletes all completed and errored session queue items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
"""Checks if the queue is empty"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
"""Checks if the queue is empty"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
"""Gets the status of the queue"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
"""Gets the status of a batch"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Cancels a session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
"""Cancels all queue items with matching batch IDs"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
"""Cancels all queue items with matching queue ID"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
limit: int,
|
||||
priority: int,
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
"""Gets a page of session queue items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Gets a session queue item by ID"""
|
||||
pass
|
||||
418
invokeai/app/services/session_queue/session_queue_common.py
Normal file
418
invokeai/app/services/session_queue/session_queue_common.py
Normal file
@@ -0,0 +1,418 @@
|
||||
import datetime
|
||||
import json
|
||||
from itertools import chain, product
|
||||
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Field, StrictStr, parse_raw_as, root_validator, validator
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
# region Errors
|
||||
|
||||
|
||||
class BatchZippedLengthError(ValueError):
|
||||
"""Raise when a batch has items of different lengths."""
|
||||
|
||||
|
||||
class BatchItemsTypeError(TypeError):
|
||||
"""Raise when a batch has items of different types."""
|
||||
|
||||
|
||||
class BatchDuplicateNodeFieldError(ValueError):
|
||||
"""Raise when a batch has duplicate node_path and field_name."""
|
||||
|
||||
|
||||
class TooManySessionsError(ValueError):
|
||||
"""Raise when too many sessions are requested."""
|
||||
|
||||
|
||||
class SessionQueueItemNotFoundError(ValueError):
|
||||
"""Raise when a queue item is not found."""
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Batch
|
||||
|
||||
BatchDataType = Union[
|
||||
StrictStr,
|
||||
float,
|
||||
int,
|
||||
]
|
||||
|
||||
|
||||
class NodeFieldValue(BaseModel):
|
||||
node_path: str = Field(description="The node into which this batch data item will be substituted.")
|
||||
field_name: str = Field(description="The field into which this batch data item will be substituted.")
|
||||
value: BatchDataType = Field(description="The value to substitute into the node/field.")
|
||||
|
||||
|
||||
class BatchDatum(BaseModel):
|
||||
node_path: str = Field(description="The node into which this batch data collection will be substituted.")
|
||||
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
|
||||
items: list[BatchDataType] = Field(
|
||||
default_factory=list, description="The list of items to substitute into the node/field."
|
||||
)
|
||||
|
||||
|
||||
BatchDataCollection: TypeAlias = list[list[BatchDatum]]
|
||||
|
||||
|
||||
class Batch(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
|
||||
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
|
||||
graph: Graph = Field(description="The graph to initialize the session with")
|
||||
runs: int = Field(
|
||||
default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
|
||||
)
|
||||
|
||||
@validator("data")
|
||||
def validate_lengths(cls, v: Optional[BatchDataCollection]):
|
||||
if v is None:
|
||||
return v
|
||||
for batch_data_list in v:
|
||||
first_item_length = len(batch_data_list[0].items) if batch_data_list and batch_data_list[0].items else 0
|
||||
for i in batch_data_list:
|
||||
if len(i.items) != first_item_length:
|
||||
raise BatchZippedLengthError("Zipped batch items must all have the same length")
|
||||
return v
|
||||
|
||||
@validator("data")
|
||||
def validate_types(cls, v: Optional[BatchDataCollection]):
|
||||
if v is None:
|
||||
return v
|
||||
for batch_data_list in v:
|
||||
for datum in batch_data_list:
|
||||
# Get the type of the first item in the list
|
||||
first_item_type = type(datum.items[0]) if datum.items else None
|
||||
for item in datum.items:
|
||||
if type(item) is not first_item_type:
|
||||
raise BatchItemsTypeError("All items in a batch must have the same type")
|
||||
return v
|
||||
|
||||
@validator("data")
|
||||
def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
|
||||
if v is None:
|
||||
return v
|
||||
paths: set[tuple[str, str]] = set()
|
||||
for batch_data_list in v:
|
||||
for datum in batch_data_list:
|
||||
pair = (datum.node_path, datum.field_name)
|
||||
if pair in paths:
|
||||
raise BatchDuplicateNodeFieldError("Each batch data must have unique node_id and field_name")
|
||||
paths.add(pair)
|
||||
return v
|
||||
|
||||
@root_validator(skip_on_failure=True)
|
||||
def validate_batch_nodes_and_edges(cls, values):
|
||||
batch_data_collection = cast(Optional[BatchDataCollection], values["data"])
|
||||
if batch_data_collection is None:
|
||||
return values
|
||||
graph = cast(Graph, values["graph"])
|
||||
for batch_data_list in batch_data_collection:
|
||||
for batch_data in batch_data_list:
|
||||
try:
|
||||
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
|
||||
except NodeNotFoundError:
|
||||
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
|
||||
if batch_data.field_name not in node.__fields__:
|
||||
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
|
||||
return values
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"graph",
|
||||
"runs",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# endregion Batch
|
||||
|
||||
|
||||
# region Queue Items
|
||||
|
||||
DEFAULT_QUEUE_ID = "default"
|
||||
|
||||
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
|
||||
|
||||
|
||||
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
|
||||
field_values_raw = queue_item_dict.get("field_values", None)
|
||||
return parse_raw_as(list[NodeFieldValue], field_values_raw) if field_values_raw is not None else None
|
||||
|
||||
|
||||
def get_session(queue_item_dict: dict) -> GraphExecutionState:
|
||||
session_raw = queue_item_dict.get("session", "{}")
|
||||
return parse_raw_as(GraphExecutionState, session_raw)
|
||||
|
||||
|
||||
class SessionQueueItemWithoutGraph(BaseModel):
|
||||
"""Session queue item without the full graph. Used for serialization."""
|
||||
|
||||
item_id: int = Field(description="The identifier of the session queue item")
|
||||
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
|
||||
priority: int = Field(default=0, description="The priority of this queue item")
|
||||
batch_id: str = Field(description="The ID of the batch associated with this queue item")
|
||||
session_id: str = Field(
|
||||
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
||||
)
|
||||
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
|
||||
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
|
||||
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
|
||||
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
|
||||
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
|
||||
queue_id: str = Field(description="The id of the queue with which this item is associated")
|
||||
field_values: Optional[list[NodeFieldValue]] = Field(
|
||||
default=None, description="The field values that were used for this queue item"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
||||
# must parse these manually
|
||||
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
||||
return SessionQueueItemDTO(**queue_item_dict)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"item_id",
|
||||
"status",
|
||||
"batch_id",
|
||||
"queue_id",
|
||||
"session_id",
|
||||
"priority",
|
||||
"session_id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
|
||||
pass
|
||||
|
||||
|
||||
class SessionQueueItem(SessionQueueItemWithoutGraph):
|
||||
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
|
||||
# must parse these manually
|
||||
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
||||
queue_item_dict["session"] = get_session(queue_item_dict)
|
||||
return SessionQueueItem(**queue_item_dict)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"item_id",
|
||||
"status",
|
||||
"batch_id",
|
||||
"queue_id",
|
||||
"session_id",
|
||||
"session",
|
||||
"priority",
|
||||
"session_id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# endregion Queue Items
|
||||
|
||||
# region Query Results
|
||||
|
||||
|
||||
class SessionQueueStatus(BaseModel):
|
||||
queue_id: str = Field(..., description="The ID of the queue")
|
||||
item_id: Optional[int] = Field(description="The current queue item id")
|
||||
batch_id: Optional[str] = Field(description="The current queue item's batch id")
|
||||
session_id: Optional[str] = Field(description="The current queue item's session id")
|
||||
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||
failed: int = Field(..., description="Number of queue items with status 'error'")
|
||||
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
|
||||
total: int = Field(..., description="Total number of queue items")
|
||||
|
||||
|
||||
class BatchStatus(BaseModel):
|
||||
queue_id: str = Field(..., description="The ID of the queue")
|
||||
batch_id: str = Field(..., description="The ID of the batch")
|
||||
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||
failed: int = Field(..., description="Number of queue items with status 'error'")
|
||||
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
|
||||
total: int = Field(..., description="Total number of queue items")
|
||||
|
||||
|
||||
class EnqueueBatchResult(BaseModel):
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
enqueued: int = Field(description="The total number of queue items enqueued")
|
||||
requested: int = Field(description="The total number of queue items requested to be enqueued")
|
||||
batch: Batch = Field(description="The batch that was enqueued")
|
||||
priority: int = Field(description="The priority of the enqueued batch")
|
||||
|
||||
|
||||
class EnqueueGraphResult(BaseModel):
|
||||
enqueued: int = Field(description="The total number of queue items enqueued")
|
||||
requested: int = Field(description="The total number of queue items requested to be enqueued")
|
||||
batch: Batch = Field(description="The batch that was enqueued")
|
||||
priority: int = Field(description="The priority of the enqueued batch")
|
||||
queue_item: SessionQueueItemDTO = Field(description="The queue item that was enqueued")
|
||||
|
||||
|
||||
class ClearResult(BaseModel):
|
||||
"""Result of clearing the session queue"""
|
||||
|
||||
deleted: int = Field(..., description="Number of queue items deleted")
|
||||
|
||||
|
||||
class PruneResult(ClearResult):
|
||||
"""Result of pruning the session queue"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CancelByBatchIDsResult(BaseModel):
|
||||
"""Result of canceling by list of batch ids"""
|
||||
|
||||
canceled: int = Field(..., description="Number of queue items canceled")
|
||||
|
||||
|
||||
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
||||
"""Result of canceling by queue id"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IsEmptyResult(BaseModel):
|
||||
"""Result of checking if the session queue is empty"""
|
||||
|
||||
is_empty: bool = Field(..., description="Whether the session queue is empty")
|
||||
|
||||
|
||||
class IsFullResult(BaseModel):
|
||||
"""Result of checking if the session queue is full"""
|
||||
|
||||
is_full: bool = Field(..., description="Whether the session queue is full")
|
||||
|
||||
|
||||
# endregion Query Results
|
||||
|
||||
|
||||
# region Util
|
||||
|
||||
|
||||
def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
|
||||
"""
|
||||
Populates the given graph with the given batch data items.
|
||||
"""
|
||||
graph_clone = graph.copy(deep=True)
|
||||
for item in node_field_values:
|
||||
node = graph_clone.get_node(item.node_path)
|
||||
if node is None:
|
||||
continue
|
||||
setattr(node, item.field_name, item.value)
|
||||
graph_clone.update_node(item.node_path, node)
|
||||
return graph_clone
|
||||
|
||||
|
||||
def create_session_nfv_tuples(
|
||||
batch: Batch, maximum: int
|
||||
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue]], None, None]:
|
||||
"""
|
||||
Create all graph permutations from the given batch data and graph. Yields tuples
|
||||
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
|
||||
that was applied to the graph.
|
||||
"""
|
||||
|
||||
# TODO: Should this be a class method on Batch?
|
||||
|
||||
data: list[list[tuple[NodeFieldValue]]] = []
|
||||
batch_data_collection = batch.data if batch.data is not None else []
|
||||
for batch_datum_list in batch_data_collection:
|
||||
# each batch_datum_list needs to be convered to NodeFieldValues and then zipped
|
||||
|
||||
node_field_values_to_zip: list[list[NodeFieldValue]] = []
|
||||
for batch_datum in batch_datum_list:
|
||||
node_field_values = [
|
||||
NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
|
||||
for item in batch_datum.items
|
||||
]
|
||||
node_field_values_to_zip.append(node_field_values)
|
||||
data.append(list(zip(*node_field_values_to_zip)))
|
||||
|
||||
# create generator to yield session,nfv tuples
|
||||
count = 0
|
||||
for _ in range(batch.runs):
|
||||
for d in product(*data):
|
||||
if count >= maximum:
|
||||
return
|
||||
flat_node_field_values = list(chain.from_iterable(d))
|
||||
graph = populate_graph(batch.graph, flat_node_field_values)
|
||||
yield (GraphExecutionState(graph=graph), flat_node_field_values)
|
||||
count += 1
|
||||
|
||||
|
||||
def calc_session_count(batch: Batch) -> int:
|
||||
"""
|
||||
Calculates the number of sessions that would be created by the batch, without incurring
|
||||
the overhead of actually generating them. Adapted from `create_sessions().
|
||||
"""
|
||||
# TODO: Should this be a class method on Batch?
|
||||
if not batch.data:
|
||||
return batch.runs
|
||||
data = []
|
||||
for batch_datum_list in batch.data:
|
||||
to_zip = []
|
||||
for batch_datum in batch_datum_list:
|
||||
batch_data_items = range(len(batch_datum.items))
|
||||
to_zip.append(batch_data_items)
|
||||
data.append(list(zip(*to_zip)))
|
||||
data_product = list(product(*data))
|
||||
return len(data_product) * batch.runs
|
||||
|
||||
|
||||
class SessionQueueValueToInsert(NamedTuple):
|
||||
"""A tuple of values to insert into the session_queue table"""
|
||||
|
||||
queue_id: str # queue_id
|
||||
session: str # session json
|
||||
session_id: str # session_id
|
||||
batch_id: str # batch_id
|
||||
field_values: Optional[str] # field_values json
|
||||
priority: int # priority
|
||||
|
||||
|
||||
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||
|
||||
|
||||
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
|
||||
values_to_insert: ValuesToInsert = []
|
||||
for session, field_values in create_session_nfv_tuples(batch, max_new_queue_items):
|
||||
# sessions must have unique id
|
||||
session.id = uuid_string()
|
||||
values_to_insert.append(
|
||||
SessionQueueValueToInsert(
|
||||
queue_id, # queue_id
|
||||
session.json(), # session (json)
|
||||
session.id, # session_id
|
||||
batch.batch_id, # batch_id
|
||||
# must use pydantic_encoder bc field_values is a list of models
|
||||
json.dumps(field_values, default=pydantic_encoder) if field_values else None, # field_values (json)
|
||||
priority, # priority
|
||||
)
|
||||
)
|
||||
return values_to_insert
|
||||
|
||||
|
||||
# endregion Util
|
||||
815
invokeai/app/services/session_queue/session_queue_sqlite.py
Normal file
815
invokeai/app/services/session_queue/session_queue_sqlite.py
Normal file
@@ -0,0 +1,815 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import Graph
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
DEFAULT_QUEUE_ID,
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
EnqueueGraphResult,
|
||||
IsEmptyResult,
|
||||
IsFullResult,
|
||||
PruneResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
calc_session_count,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||
|
||||
|
||||
class SqliteSessionQueue(SessionQueueBase):
|
||||
__invoker: Invoker
|
||||
__conn: sqlite3.Connection
|
||||
__cursor: sqlite3.Cursor
|
||||
__lock: threading.Lock
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
super().__init__()
|
||||
self.__conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self.__conn.row_factory = sqlite3.Row
|
||||
self.__cursor = self.__conn.cursor()
|
||||
self.__lock = lock
|
||||
self._create_tables()
|
||||
|
||||
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||
return event[1]["event"] in match_in
|
||||
|
||||
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
|
||||
event_name = event[1]["event"]
|
||||
match event_name:
|
||||
case "graph_execution_state_complete":
|
||||
await self._handle_complete_event(event)
|
||||
case "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
|
||||
await self._handle_error_event(event)
|
||||
case "session_canceled":
|
||||
await self._handle_cancel_event(event)
|
||||
return event
|
||||
|
||||
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
# When a queue item has an error, we get an error event, then a completed event.
|
||||
# Mark the queue item completed only if it isn't already marked completed, e.g.
|
||||
# by a previously-handled error event.
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
async def _handle_error_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
error = event[1]["data"]["error"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
# always set to failed if have an error, even if previously the item was marked completed or canceled
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the session queue tables, indicies, and triggers"""
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS session_queue (
|
||||
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
|
||||
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
|
||||
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
|
||||
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
|
||||
field_values TEXT, -- NULL if no values are associated with this queue item
|
||||
session TEXT NOT NULL, -- the session to be executed
|
||||
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
|
||||
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
|
||||
error TEXT, -- any errors associated with this queue item
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
|
||||
started_at DATETIME, -- updated via trigger
|
||||
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
|
||||
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
|
||||
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
|
||||
AFTER UPDATE OF status ON session_queue
|
||||
FOR EACH ROW
|
||||
WHEN
|
||||
NEW.status = 'completed'
|
||||
OR NEW.status = 'failed'
|
||||
OR NEW.status = 'canceled'
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = NEW.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
|
||||
AFTER UPDATE OF status ON session_queue
|
||||
FOR EACH ROW
|
||||
WHEN
|
||||
NEW.status = 'in_progress'
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = NEW.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
|
||||
AFTER UPDATE
|
||||
ON session_queue FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = old.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||
This is necessary because the invoker may have been killed while processing a queue item.
|
||||
"""
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
WHERE status = 'in_progress';
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
def _get_current_queue_size(self, queue_id: str) -> int:
|
||||
"""Gets the current number of pending queue items"""
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(int, self.__cursor.fetchone()[0])
|
||||
|
||||
def _get_highest_priority(self, queue_id: str) -> int:
|
||||
"""Gets the highest priority value in the queue"""
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
|
||||
|
||||
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
|
||||
enqueue_result = self.enqueue_batch(queue_id=queue_id, batch=Batch(graph=graph), prepend=prepend)
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND batch_id = ?
|
||||
""",
|
||||
(queue_id, enqueue_result.batch.batch_id),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with batch id {enqueue_result.batch.batch_id}")
|
||||
return EnqueueGraphResult(
|
||||
**enqueue_result.dict(),
|
||||
queue_item=SessionQueueItemDTO.from_dict(dict(result)),
|
||||
)
|
||||
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
|
||||
# TODO: how does this work in a multi-user scenario?
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.get_config().max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
|
||||
priority = 0
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
|
||||
requested_count = calc_session_count(batch)
|
||||
values_to_insert = prepare_values_to_insert(
|
||||
queue_id=queue_id,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
max_new_queue_items=max_new_queue_items,
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
|
||||
if requested_count > enqueued_count:
|
||||
values_to_insert = values_to_insert[:max_new_queue_items]
|
||||
|
||||
self.__cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
enqueue_result = EnqueueBatchResult(
|
||||
queue_id=queue_id,
|
||||
requested=requested_count,
|
||||
enqueued=enqueued_count,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
)
|
||||
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
|
||||
return enqueue_result
|
||||
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
queue_item = SessionQueueItem.from_dict(dict(result))
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
|
||||
return queue_item
|
||||
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
|
||||
def _set_queue_item_status(
|
||||
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
|
||||
) -> SessionQueueItem:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = ?, error = ?
|
||||
WHERE item_id = ?
|
||||
""",
|
||||
(status, error, item_id),
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return IsEmptyResult(is_empty=is_empty)
|
||||
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return IsFullResult(is_full=is_full)
|
||||
|
||||
def delete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self.get_queue_item(item_id=item_id)
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return queue_item
|
||||
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
self.__invoker.services.events.emit_queue_cleared(queue_id)
|
||||
return ClearResult(deleted=count)
|
||||
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
try:
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND (
|
||||
status = 'completed'
|
||||
OR status = 'failed'
|
||||
OR status = 'canceled'
|
||||
)
|
||||
"""
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
self.__invoker.services.queue.cancel(queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
)
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
try:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
placeholders = ", ".join(["?" for _ in batch_ids])
|
||||
where = f"""--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND batch_id IN ({placeholders})
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = [queue_id] + batch_ids
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
{where};
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id is ?
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = [queue_id]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
{where};
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByQueueIDResult(canceled=count)
|
||||
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
|
||||
def list_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
limit: int,
|
||||
priority: int,
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
try:
|
||||
item_id = cursor
|
||||
self.__lock.acquire()
|
||||
query = """--sql
|
||||
SELECT item_id,
|
||||
status,
|
||||
priority,
|
||||
field_values,
|
||||
error,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
started_at,
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(status)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
self.__cursor.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
items = [SessionQueueItemDTO.from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
# remove the extra item
|
||||
items.pop()
|
||||
has_more = True
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(row[1] for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
return SessionQueueStatus(
|
||||
queue_id=queue_id,
|
||||
item_id=current_item.item_id if current_item else None,
|
||||
session_id=current_item.session_id if current_item else None,
|
||||
batch_id=current_item.batch_id if current_item else None,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
completed=counts.get("completed", 0),
|
||||
failed=counts.get("failed", 0),
|
||||
canceled=counts.get("canceled", 0),
|
||||
total=total,
|
||||
)
|
||||
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
return BatchStatus(
|
||||
batch_id=batch_id,
|
||||
queue_id=queue_id,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
completed=counts.get("completed", 0),
|
||||
failed=counts.get("failed", 0),
|
||||
canceled=counts.get("canceled", 0),
|
||||
total=total,
|
||||
)
|
||||
14
invokeai/app/services/shared/models.py
Normal file
14
invokeai/app/services/shared/models.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
||||
|
||||
|
||||
class CursorPaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
||||
"""Cursor-paginated results"""
|
||||
|
||||
limit: int = Field(..., description="Limit of items to get")
|
||||
has_more: bool = Field(..., description="Whether there are more items available")
|
||||
items: list[GenericBaseModel] = Field(..., description="Items")
|
||||
@@ -1,5 +1,5 @@
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
import threading
|
||||
from typing import Generic, Optional, TypeVar, get_args
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
@@ -12,23 +12,19 @@ sqlite_memory = ":memory:"
|
||||
|
||||
|
||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
_filename: str
|
||||
_table_name: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_id_field: str
|
||||
_lock: Lock
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
|
||||
def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.Lock, id_field: str = "id"):
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._lock = Lock()
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._lock = lock
|
||||
self._conn = conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
@@ -49,8 +45,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
parsed = parse_raw_as(item_type, item)
|
||||
return parsed
|
||||
return parse_raw_as(item_type, item)
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
|
||||
3
invokeai/app/services/thread.py
Normal file
3
invokeai/app/services/thread.py
Normal file
@@ -0,0 +1,3 @@
|
||||
import threading
|
||||
|
||||
lock = threading.Lock()
|
||||
@@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -21,3 +22,8 @@ SEED_MAX = np.iinfo(np.uint32).max
|
||||
def get_random_seed():
|
||||
rng = np.random.default_rng(seed=None)
|
||||
return int(rng.integers(0, SEED_MAX))
|
||||
|
||||
|
||||
def uuid_string():
|
||||
res = uuid.uuid4()
|
||||
return str(res)
|
||||
|
||||
@@ -110,6 +110,9 @@ def stable_diffusion_step_callback(
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
|
||||
46
invokeai/backend/image_util/invoke_metadata.py
Normal file
46
invokeai/backend/image_util/invoke_metadata.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""Very simple functions to fetch and print metadata from InvokeAI-generated images."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_invokeai_metadata(image_path: Path) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieve "invokeai_metadata" field from png image.
|
||||
|
||||
:param image_path: Path to the image to read metadata from.
|
||||
May raise:
|
||||
OSError -- image path not found
|
||||
KeyError -- image doesn't contain the metadata field
|
||||
"""
|
||||
image: Image = Image.open(image_path)
|
||||
return json.loads(image.text["invokeai_metadata"])
|
||||
|
||||
|
||||
def print_invokeai_metadata(image_path: Path):
|
||||
"""Pretty-print the metadata."""
|
||||
try:
|
||||
metadata = get_invokeai_metadata(image_path)
|
||||
print(f"{image_path}:\n{json.dumps(metadata, sort_keys=True, indent=4)}")
|
||||
except OSError:
|
||||
print(f"{image_path}:\nNo file found.")
|
||||
except KeyError:
|
||||
print(f"{image_path}:\nNo metadata found.")
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the command-line utility."""
|
||||
image_paths = sys.argv[1:]
|
||||
if not image_paths:
|
||||
print(f"Usage: {Path(sys.argv[0]).name} image1 image2 image3 ...")
|
||||
print("\nPretty-print InvokeAI image metadata from the listed png files.")
|
||||
sys.exit(-1)
|
||||
for img in image_paths:
|
||||
print_invokeai_metadata(img)
|
||||
@@ -326,6 +326,16 @@ class ModelInstall(object):
|
||||
elif f"learned_embeds.{suffix}" in files:
|
||||
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
|
||||
break
|
||||
elif "image_encoder.txt" in files and f"ip_adapter.{suffix}" in files: # IP-Adapter
|
||||
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging)
|
||||
break
|
||||
elif f"model.{suffix}" in files and "config.json" in files:
|
||||
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
|
||||
# by InvokeAI for use with IP-Adapters.
|
||||
files = ["config.json", f"model.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging)
|
||||
break
|
||||
if not location:
|
||||
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
|
||||
return {}
|
||||
@@ -534,14 +544,17 @@ def hf_download_with_resume(
|
||||
logger.info(f"{model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
total=total + exist_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
) as bar:
|
||||
with (
|
||||
open(model_dest, open_mode) as file,
|
||||
tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
total=total + exist_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
) as bar,
|
||||
):
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
|
||||
45
invokeai/backend/ip_adapter/README.md
Normal file
45
invokeai/backend/ip_adapter/README.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# IP-Adapter Model Formats
|
||||
|
||||
The official IP-Adapter models are released here: [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter)
|
||||
|
||||
This official model repo does not integrate well with InvokeAI's current approach to model management, so we have defined a new file structure for IP-Adapter models. The InvokeAI format is described below.
|
||||
|
||||
## CLIP Vision Models
|
||||
|
||||
CLIP Vision models are organized in `diffusers`` format. The expected directory structure is:
|
||||
|
||||
```bash
|
||||
ip_adapter_sd_image_encoder/
|
||||
├── config.json
|
||||
└── model.safetensors
|
||||
```
|
||||
|
||||
## IP-Adapter Models
|
||||
|
||||
IP-Adapter models are stored in a directory containing two files
|
||||
- `image_encoder.txt`: A text file containing the model identifier for the CLIP Vision encoder that is intended to be used with this IP-Adapter model.
|
||||
- `ip_adapter.bin`: The IP-Adapter weights.
|
||||
|
||||
Sample directory structure:
|
||||
```bash
|
||||
ip_adapter_sd15/
|
||||
├── image_encoder.txt
|
||||
└── ip_adapter.bin
|
||||
```
|
||||
|
||||
### Why save the weights in a .safetensors file?
|
||||
|
||||
The weights in `ip_adapter.bin` are stored in a nested dict, which is not supported by `safetensors`. This could be solved by splitting `ip_adapter.bin` into multiple files, but for now we have decided to maintain consistency with the checkpoint structure used in the official [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repo.
|
||||
|
||||
## InvokeAI Hosted IP-Adapters
|
||||
|
||||
Image Encoders:
|
||||
- [InvokeAI/ip_adapter_sd_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder)
|
||||
- [InvokeAI/ip_adapter_sdxl_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sdxl_image_encoder)
|
||||
|
||||
IP-Adapters:
|
||||
- [InvokeAI/ip_adapter_sd15](https://huggingface.co/InvokeAI/ip_adapter_sd15)
|
||||
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
|
||||
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
|
||||
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
|
||||
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
|
||||
0
invokeai/backend/ip_adapter/__init__.py
Normal file
0
invokeai/backend/ip_adapter/__init__.py
Normal file
162
invokeai/backend/ip_adapter/attention_processor.py
Normal file
162
invokeai/backend/ip_adapter/attention_processor.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||
# and modified as needed
|
||||
|
||||
# tencent-ailab comment:
|
||||
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
|
||||
|
||||
|
||||
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
|
||||
# loading.
|
||||
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
||||
def __init__(self):
|
||||
DiffusersAttnProcessor2_0.__init__(self)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
|
||||
ip_adapter_image_prompt_embeds parameter.
|
||||
"""
|
||||
return DiffusersAttnProcessor2_0.__call__(
|
||||
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
|
||||
)
|
||||
|
||||
|
||||
class IPAttnProcessor2_0(torch.nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapater for PyTorch 2.0.
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
if encoder_hidden_states is not None:
|
||||
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
||||
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
# The batch dimensions should match.
|
||||
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The channel dimensions should match.
|
||||
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
|
||||
ip_hidden_states = ip_adapter_image_prompt_embeds
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if ip_hidden_states is not None:
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
217
invokeai/backend/ip_adapter/ip_adapter.py
Normal file
217
invokeai/backend/ip_adapter/ip_adapter.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||
# and modified as needed
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
|
||||
from .resampler import Resampler
|
||||
|
||||
|
||||
class ImageProjModel(torch.nn.Module):
|
||||
"""Image Projection Model"""
|
||||
|
||||
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
||||
super().__init__()
|
||||
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.clip_extra_context_tokens = clip_extra_context_tokens
|
||||
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
||||
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
|
||||
"""Initialize an ImageProjModel from a state_dict.
|
||||
|
||||
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
|
||||
|
||||
Args:
|
||||
state_dict (dict[torch.Tensor]): The state_dict of model weights.
|
||||
clip_extra_context_tokens (int, optional): Defaults to 4.
|
||||
|
||||
Returns:
|
||||
ImageProjModel
|
||||
"""
|
||||
cross_attention_dim = state_dict["norm.weight"].shape[0]
|
||||
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
|
||||
|
||||
model = cls(cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens)
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
def forward(self, image_embeds):
|
||||
embeds = image_embeds
|
||||
clip_extra_context_tokens = self.proj(embeds).reshape(
|
||||
-1, self.clip_extra_context_tokens, self.cross_attention_dim
|
||||
)
|
||||
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class IPAdapter:
|
||||
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_dict: dict[torch.Tensor],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_tokens: int = 4,
|
||||
):
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self._num_tokens = num_tokens
|
||||
|
||||
self._clip_image_processor = CLIPImageProcessor()
|
||||
|
||||
self._state_dict = state_dict
|
||||
|
||||
self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
|
||||
|
||||
# The _attn_processors will be initialized later when we have access to the UNet.
|
||||
self._attn_processors = None
|
||||
|
||||
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
|
||||
self._image_proj_model.to(device=self.device, dtype=self.dtype)
|
||||
if self._attn_processors is not None:
|
||||
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def _init_image_proj_model(self, state_dict):
|
||||
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
|
||||
|
||||
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
|
||||
"""Prepare a dict of attention processors that can later be injected into a unet, and load the IP-Adapter
|
||||
attention weights into them.
|
||||
|
||||
Note that the `unet` param is only used to determine attention block dimensions and naming.
|
||||
TODO(ryand): As a future improvement, this could all be inferred from the state_dict when the IPAdapter is
|
||||
intialized.
|
||||
"""
|
||||
attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
if cross_attention_dim is None:
|
||||
attn_procs[name] = AttnProcessor2_0()
|
||||
else:
|
||||
attn_procs[name] = IPAttnProcessor2_0(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
).to(self.device, dtype=self.dtype)
|
||||
|
||||
ip_layers = torch.nn.ModuleList(attn_procs.values())
|
||||
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
|
||||
self._attn_processors = attn_procs
|
||||
self._state_dict = None
|
||||
|
||||
# @genomancer: pushed scaling back out into its own method (like original Tencent implementation)
|
||||
# which makes implementing begin_step_percent and end_step_percent easier
|
||||
# but based on self._attn_processors (ala @Ryan) instead of original Tencent unet.attn_processors,
|
||||
# which should make it easier to implement multiple IPAdapters
|
||||
def set_scale(self, scale):
|
||||
if self._attn_processors is not None:
|
||||
for attn_processor in self._attn_processors.values():
|
||||
if isinstance(attn_processor, IPAttnProcessor2_0):
|
||||
attn_processor.scale = scale
|
||||
|
||||
@contextmanager
|
||||
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel, scale: float):
|
||||
"""A context manager that patches `unet` with this IP-Adapter's attention processors while it is active.
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
if self._attn_processors is None:
|
||||
# We only have to call _prepare_attention_processors(...) once, and then the result is cached and can be
|
||||
# used on any UNet model (with the same dimensions).
|
||||
self._prepare_attention_processors(unet)
|
||||
|
||||
# Set scale
|
||||
self.set_scale(scale)
|
||||
# for attn_processor in self._attn_processors.values():
|
||||
# if isinstance(attn_processor, IPAttnProcessor2_0):
|
||||
# attn_processor.scale = scale
|
||||
|
||||
orig_attn_processors = unet.attn_processors
|
||||
|
||||
# Make a (moderately-) shallow copy of the self._attn_processors dict, because unet.set_attn_processor(...)
|
||||
# actually pops elements from the passed dict.
|
||||
ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}
|
||||
|
||||
try:
|
||||
unet.set_attn_processor(ip_adapter_attn_processors)
|
||||
yield None
|
||||
finally:
|
||||
unet.set_attn_processor(orig_attn_processors)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
|
||||
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
|
||||
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
||||
|
||||
class IPAdapterPlus(IPAdapter):
|
||||
"""IP-Adapter with fine-grained features"""
|
||||
|
||||
def _init_image_proj_model(self, state_dict):
|
||||
return Resampler.from_state_dict(
|
||||
state_dict=state_dict,
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=12,
|
||||
num_queries=self._num_tokens,
|
||||
ff_mult=4,
|
||||
).to(self.device, dtype=self.dtype)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(self.device, dtype=self.dtype)
|
||||
clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
|
||||
uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
|
||||
-2
|
||||
]
|
||||
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
||||
|
||||
def build_ip_adapter(
|
||||
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
|
||||
) -> Union[IPAdapter, IPAdapterPlus]:
|
||||
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
|
||||
|
||||
# Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
|
||||
# contains.
|
||||
is_plus = "proj.weight" not in state_dict["image_proj"]
|
||||
|
||||
if is_plus:
|
||||
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
|
||||
else:
|
||||
return IPAdapter(state_dict, device=device, dtype=dtype)
|
||||
158
invokeai/backend/ip_adapter/resampler.py
Normal file
158
invokeai/backend/ip_adapter/resampler.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||
|
||||
# tencent ailab comment: modified from
|
||||
# https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# FFN
|
||||
def FeedForward(dim, mult=4):
|
||||
inner_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias=False),
|
||||
nn.GELU(),
|
||||
nn.Linear(inner_dim, dim, bias=False),
|
||||
)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
|
||||
bs, length, width = x.shape
|
||||
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||
x = x.view(bs, length, heads, -1)
|
||||
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||
x = x.transpose(1, 2)
|
||||
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||
x = x.reshape(bs, heads, length, -1)
|
||||
return x
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
|
||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.dim_head = dim_head
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x, latents):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): image features
|
||||
shape (b, n1, D)
|
||||
latent (torch.Tensor): latent features
|
||||
shape (b, n2, D)
|
||||
"""
|
||||
x = self.norm1(x)
|
||||
latents = self.norm2(latents)
|
||||
|
||||
b, l, _ = latents.shape
|
||||
|
||||
q = self.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
q = reshape_tensor(q, self.heads)
|
||||
k = reshape_tensor(k, self.heads)
|
||||
v = reshape_tensor(v, self.heads)
|
||||
|
||||
# attention
|
||||
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
out = weight @ v
|
||||
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Resampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim=1024,
|
||||
depth=8,
|
||||
dim_head=64,
|
||||
heads=16,
|
||||
num_queries=8,
|
||||
embedding_dim=768,
|
||||
output_dim=1024,
|
||||
ff_mult=4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||
|
||||
self.proj_out = nn.Linear(dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(output_dim)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
||||
FeedForward(dim=dim, mult=ff_mult),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
|
||||
"""A convenience function that initializes a Resampler from a state_dict.
|
||||
|
||||
Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
|
||||
writing, we did not have a need for inferring ALL of the shape parameters from the state_dict, but this would be
|
||||
possible if needed in the future.
|
||||
|
||||
Args:
|
||||
state_dict (dict[torch.Tensor]): The state_dict to load.
|
||||
depth (int, optional):
|
||||
dim_head (int, optional):
|
||||
heads (int, optional):
|
||||
ff_mult (int, optional):
|
||||
|
||||
Returns:
|
||||
Resampler
|
||||
"""
|
||||
dim = state_dict["latents"].shape[2]
|
||||
num_queries = state_dict["latents"].shape[1]
|
||||
embedding_dim = state_dict["proj_in.weight"].shape[-1]
|
||||
output_dim = state_dict["norm_out.weight"].shape[0]
|
||||
|
||||
model = cls(
|
||||
dim=dim,
|
||||
depth=depth,
|
||||
dim_head=dim_head,
|
||||
heads=heads,
|
||||
num_queries=num_queries,
|
||||
embedding_dim=embedding_dim,
|
||||
output_dim=output_dim,
|
||||
ff_mult=ff_mult,
|
||||
)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
def forward(self, x):
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
latents = attn(x, latents) + latents
|
||||
latents = ff(latents) + latents
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
||||
@@ -25,6 +25,7 @@ Models are described using four attributes:
|
||||
ModelType.Lora -- a LoRA or LyCORIS fine-tune
|
||||
ModelType.TextualInversion -- a textual inversion embedding
|
||||
ModelType.ControlNet -- a ControlNet model
|
||||
ModelType.IPAdapter -- an IPAdapter model
|
||||
|
||||
3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
|
||||
BaseModelType.StableDiffusion1
|
||||
@@ -1000,8 +1001,8 @@ class ModelManager(object):
|
||||
new_models_found = True
|
||||
except DuplicateModelException as e:
|
||||
self.logger.warning(e)
|
||||
except InvalidModelException:
|
||||
self.logger.warning(f"Not a valid model: {model_path}")
|
||||
except InvalidModelException as e:
|
||||
self.logger.warning(f"Not a valid model: {model_path}. {e}")
|
||||
except NotImplementedError as e:
|
||||
self.logger.warning(e)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Literal, Optional, Union
|
||||
@@ -8,6 +9,8 @@ import torch
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
||||
|
||||
from .models import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
@@ -51,7 +54,9 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -118,14 +123,18 @@ class ModelProbe(object):
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
),
|
||||
format=format,
|
||||
image_size=1024
|
||||
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
|
||||
else 768
|
||||
if (
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
else 512,
|
||||
image_size=(
|
||||
1024
|
||||
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
|
||||
else (
|
||||
768
|
||||
if (
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
else 512
|
||||
)
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
@@ -170,6 +179,7 @@ class ModelProbe(object):
|
||||
Get the model type of a hugging-face style folder.
|
||||
"""
|
||||
class_name = None
|
||||
error_hint = None
|
||||
if model:
|
||||
class_name = model.__class__.__name__
|
||||
else:
|
||||
@@ -177,9 +187,10 @@ class ModelProbe(object):
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "learned_embeds.bin").exists():
|
||||
return ModelType.TextualInversion
|
||||
|
||||
if (folder_path / "pytorch_lora_weights.bin").exists():
|
||||
return ModelType.Lora
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
return ModelType.IPAdapter
|
||||
|
||||
i = folder_path / "model_index.json"
|
||||
c = folder_path / "config.json"
|
||||
@@ -188,13 +199,24 @@ class ModelProbe(object):
|
||||
if config_path:
|
||||
with open(config_path, "r") as file:
|
||||
conf = json.load(file)
|
||||
class_name = conf["_class_name"]
|
||||
if "_class_name" in conf:
|
||||
class_name = conf["_class_name"]
|
||||
elif "architectures" in conf:
|
||||
class_name = conf["architectures"][0]
|
||||
else:
|
||||
class_name = None
|
||||
else:
|
||||
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
||||
|
||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||
return type
|
||||
else:
|
||||
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
|
||||
|
||||
# give up
|
||||
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
|
||||
raise InvalidModelException(
|
||||
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
||||
@@ -366,6 +388,16 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
||||
|
||||
|
||||
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
@@ -438,16 +470,32 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self._config_looks_like_sdxl():
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif self._name_looks_like_sdxl():
|
||||
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def _config_looks_like_sdxl(self) -> bool:
|
||||
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
return (
|
||||
BaseModelType.StableDiffusionXL
|
||||
if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
else BaseModelType.StableDiffusion1
|
||||
)
|
||||
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
|
||||
def _name_looks_like_sdxl(self) -> bool:
|
||||
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
||||
|
||||
def _guess_name(self) -> str:
|
||||
name = self.folder_path.name
|
||||
if name == "vae":
|
||||
name = self.folder_path.parent.name
|
||||
return name
|
||||
|
||||
|
||||
class TextualInversionFolderProbe(FolderProbeBase):
|
||||
@@ -485,11 +533,13 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
base_model = (
|
||||
BaseModelType.StableDiffusion1
|
||||
if dimension == 768
|
||||
else BaseModelType.StableDiffusion2
|
||||
if dimension == 1024
|
||||
else BaseModelType.StableDiffusionXL
|
||||
if dimension == 2048
|
||||
else None
|
||||
else (
|
||||
BaseModelType.StableDiffusion2
|
||||
if dimension == 1024
|
||||
else BaseModelType.StableDiffusionXL
|
||||
if dimension == 2048
|
||||
else None
|
||||
)
|
||||
)
|
||||
if not base_model:
|
||||
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
|
||||
@@ -509,15 +559,47 @@ class LoRAFolderProbe(FolderProbeBase):
|
||||
return LoRACheckpointProbe(model_file, None).get_base_type()
|
||||
|
||||
|
||||
class IPAdapterFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return IPAdapterModelFormat.InvokeAI.value
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
model_file = self.folder_path / "ip_adapter.bin"
|
||||
if not model_file.exists():
|
||||
raise InvalidModelException("Unknown IP-Adapter model format.")
|
||||
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||
if cross_attention_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif cross_attention_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif cross_attention_dim == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
|
||||
|
||||
|
||||
class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
############## register probe classes ######
|
||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||
|
||||
@@ -79,7 +79,7 @@ class ModelSearch(ABC):
|
||||
self._models_found += 1
|
||||
self._scanned_dirs.add(path)
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
@@ -90,7 +90,7 @@ class ModelSearch(ABC):
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
||||
|
||||
|
||||
class FindModels(ModelSearch):
|
||||
|
||||
@@ -18,7 +18,9 @@ from .base import ( # noqa: F401
|
||||
SilenceWarnings,
|
||||
SubModelType,
|
||||
)
|
||||
from .clip_vision import CLIPVisionModel
|
||||
from .controlnet import ControlNetModel # TODO:
|
||||
from .ip_adapter import IPAdapterModel
|
||||
from .lora import LoRAModel
|
||||
from .sdxl import StableDiffusionXLModel
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
@@ -34,6 +36,8 @@ MODEL_CLASSES = {
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.IPAdapter: IPAdapterModel,
|
||||
ModelType.CLIPVision: CLIPVisionModel,
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
@@ -42,6 +46,8 @@ MODEL_CLASSES = {
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.IPAdapter: IPAdapterModel,
|
||||
ModelType.CLIPVision: CLIPVisionModel,
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelType.Main: StableDiffusionXLModel,
|
||||
@@ -51,6 +57,8 @@ MODEL_CLASSES = {
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
ModelType.IPAdapter: IPAdapterModel,
|
||||
ModelType.CLIPVision: CLIPVisionModel,
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelType.Main: StableDiffusionXLModel,
|
||||
@@ -60,6 +68,19 @@ MODEL_CLASSES = {
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
ModelType.IPAdapter: IPAdapterModel,
|
||||
ModelType.CLIPVision: CLIPVisionModel,
|
||||
},
|
||||
BaseModelType.Any: {
|
||||
ModelType.CLIPVision: CLIPVisionModel,
|
||||
# The following model types are not expected to be used with BaseModelType.Any.
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
ModelType.Main: StableDiffusion2Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.IPAdapter: IPAdapterModel,
|
||||
},
|
||||
# BaseModelType.Kandinsky2_1: {
|
||||
# ModelType.Main: Kandinsky2_1Model,
|
||||
|
||||
@@ -36,6 +36,7 @@ class ModelNotFoundException(Exception):
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
Any = "any" # For models that are not associated with any particular base model.
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusionXL = "sdxl"
|
||||
@@ -50,6 +51,8 @@ class ModelType(str, Enum):
|
||||
Lora = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
|
||||
82
invokeai/backend/model_management/models/clip_vision.py
Normal file
82
invokeai/backend/model_management/models/clip_vision.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.model_management.models.base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
calc_model_size_by_data,
|
||||
calc_model_size_by_fs,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionModelFormat(str, Enum):
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class CLIPVisionModel(ModelBase):
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[CLIPVisionModelFormat.Diffusers]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.CLIPVision
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str) -> str:
|
||||
if not os.path.exists(path):
|
||||
raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.")
|
||||
|
||||
if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
|
||||
return CLIPVisionModelFormat.Diffusers
|
||||
|
||||
raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}")
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in a CLIP Vision model.")
|
||||
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
) -> CLIPVisionModelWithProjection:
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in a CLIP Vision model.")
|
||||
|
||||
model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)
|
||||
|
||||
# Calculate a more accurate model size.
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
format = cls.detect_format(model_path)
|
||||
if format == CLIPVisionModelFormat.Diffusers:
|
||||
return model_path
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: '{format}'.")
|
||||
92
invokeai/backend/model_management/models/ip_adapter.py
Normal file
92
invokeai/backend/model_management/models/ip_adapter.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import os
|
||||
import typing
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter
|
||||
from invokeai.backend.model_management.models.base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
|
||||
class IPAdapterModelFormat(str, Enum):
|
||||
# The custom IP-Adapter model format defined by InvokeAI.
|
||||
InvokeAI = "invokeai"
|
||||
|
||||
|
||||
class IPAdapterModel(ModelBase):
|
||||
class InvokeAIConfig(ModelConfigBase):
|
||||
model_format: Literal[IPAdapterModelFormat.InvokeAI]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.IPAdapter
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.model_size = os.path.getsize(self.model_path)
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str) -> str:
|
||||
if not os.path.exists(path):
|
||||
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
|
||||
|
||||
if os.path.isdir(path):
|
||||
model_file = os.path.join(path, "ip_adapter.bin")
|
||||
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
|
||||
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
|
||||
return IPAdapterModelFormat.InvokeAI
|
||||
|
||||
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
||||
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
) -> typing.Union[IPAdapter, IPAdapterPlus]:
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
||||
|
||||
return build_ip_adapter(
|
||||
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
format = cls.detect_format(model_path)
|
||||
if format == IPAdapterModelFormat.InvokeAI:
|
||||
return model_path
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: '{format}'.")
|
||||
|
||||
|
||||
def get_ip_adapter_image_encoder_model_id(model_path: str):
|
||||
"""Read the ID of the image encoder associated with the IP-Adapter at `model_path`."""
|
||||
image_encoder_config_file = os.path.join(model_path, "image_encoder.txt")
|
||||
|
||||
with open(image_encoder_config_file, "r") as f:
|
||||
image_encoder_model = f.readline().strip()
|
||||
|
||||
return image_encoder_model
|
||||
@@ -1,15 +1,6 @@
|
||||
"""
|
||||
Initialization file for the invokeai.backend.stable_diffusion package
|
||||
"""
|
||||
from .diffusers_pipeline import ( # noqa: F401
|
||||
ConditioningData,
|
||||
PipelineIntermediateState,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401
|
||||
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||
from .diffusion.shared_invokeai_diffusion import ( # noqa: F401
|
||||
BasicConditioningInfo,
|
||||
PostprocessingSettings,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
import math
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import einops
|
||||
@@ -23,9 +23,11 @@ from pydantic import Field
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
|
||||
|
||||
from ..util import auto_detect_slice_size, normalize_device
|
||||
from .diffusion import AttentionMapSaver, BasicConditioningInfo, InvokeAIDiffuserComponent, PostprocessingSettings
|
||||
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -95,7 +97,7 @@ class AddsMaskGuidance:
|
||||
# Mask anything that has the same shape as prev_sample, return others as-is.
|
||||
return output_class(
|
||||
{
|
||||
k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
|
||||
k: self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v
|
||||
for k, v in step_output.items()
|
||||
}
|
||||
)
|
||||
@@ -162,39 +164,13 @@ class ControlNetData:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: BasicConditioningInfo
|
||||
text_embeddings: BasicConditioningInfo
|
||||
guidance_scale: Union[float, List[float]]
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
"""
|
||||
extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
|
||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||
"""
|
||||
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
|
||||
"""
|
||||
postprocessing_settings: Optional[PostprocessingSettings] = None
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.text_embeddings.dtype
|
||||
|
||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
||||
scheduler_args = dict(self.scheduler_args)
|
||||
step_method = inspect.signature(scheduler.step)
|
||||
for name, value in kwargs.items():
|
||||
try:
|
||||
step_method.bind_partial(**{name: value})
|
||||
except TypeError:
|
||||
# FIXME: don't silently discard arguments
|
||||
pass # debug("%s does not accept argument named %r", scheduler, name)
|
||||
else:
|
||||
scheduler_args[name] = value
|
||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
|
||||
class IPAdapterData:
|
||||
ip_adapter_model: IPAdapter = Field(default=None)
|
||||
# TODO: change to polymorphic so can do different weights per step (once implemented...)
|
||||
weight: Union[float, List[float]] = Field(default=1.0)
|
||||
# weight: float = Field(default=1.0)
|
||||
begin_step_percent: float = Field(default=0.0)
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -277,6 +253,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
|
||||
self.control_model = control_model
|
||||
self.use_ip_adapter = False
|
||||
|
||||
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
|
||||
"""
|
||||
@@ -349,6 +326,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance: List[Callable] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[IPAdapterData] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
masked_latents: Optional[torch.Tensor] = None,
|
||||
seed: Optional[int] = None,
|
||||
@@ -400,6 +378,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
callback=callback,
|
||||
)
|
||||
finally:
|
||||
@@ -419,6 +398,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
*,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[IPAdapterData] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
):
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
@@ -431,12 +411,26 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if timesteps.shape[0] == 0:
|
||||
return latents, attention_map_saver
|
||||
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
with self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
):
|
||||
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
|
||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=conditioning_data.extra,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
)
|
||||
self.use_ip_adapter = False
|
||||
elif ip_adapter_data is not None:
|
||||
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
||||
# As it is now, the IP-Adapter will silently be skipped.
|
||||
weight = ip_adapter_data.weight[0] if isinstance(ip_adapter_data.weight, List) else ip_adapter_data.weight
|
||||
attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention(
|
||||
unet=self.invokeai_diffuser.model,
|
||||
scale=weight,
|
||||
)
|
||||
self.use_ip_adapter = True
|
||||
else:
|
||||
attn_ctx = nullcontext()
|
||||
|
||||
with attn_ctx:
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
@@ -459,6 +453,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
total_step_count=len(timesteps),
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
)
|
||||
latents = step_output.prev_sample
|
||||
|
||||
@@ -504,6 +499,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
total_step_count: int,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[IPAdapterData] = None,
|
||||
):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
@@ -514,6 +510,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||
|
||||
# handle IP-Adapter
|
||||
if self.use_ip_adapter and ip_adapter_data is not None: # somewhat redundant but logic is clearer
|
||||
first_adapter_step = math.floor(ip_adapter_data.begin_step_percent * total_step_count)
|
||||
last_adapter_step = math.ceil(ip_adapter_data.end_step_percent * total_step_count)
|
||||
weight = (
|
||||
ip_adapter_data.weight[step_index]
|
||||
if isinstance(ip_adapter_data.weight, List)
|
||||
else ip_adapter_data.weight
|
||||
)
|
||||
if step_index >= first_adapter_step and step_index <= last_adapter_step:
|
||||
# only apply IP-Adapter if current step is within the IP-Adapter's begin/end step range
|
||||
# ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
|
||||
ip_adapter_data.ip_adapter_model.set_scale(weight)
|
||||
else:
|
||||
# otherwise, set IP-Adapter scale to 0, so it has no effect
|
||||
ip_adapter_data.ip_adapter_model.set_scale(0.0)
|
||||
|
||||
# handle ControlNet(s)
|
||||
# default is no controlnet, so set controlnet processing output to None
|
||||
controlnet_down_block_samples, controlnet_mid_block_sample = None, None
|
||||
if control_data is not None:
|
||||
|
||||
@@ -3,9 +3,4 @@ Initialization file for invokeai.models.diffusion
|
||||
"""
|
||||
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
|
||||
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||
from .shared_invokeai_diffusion import ( # noqa: F401
|
||||
BasicConditioningInfo,
|
||||
InvokeAIDiffuserComponent,
|
||||
PostprocessingSettings,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
|
||||
101
invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
Normal file
101
invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .cross_attention_control import Arguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtraConditioningInfo:
|
||||
tokens_count_including_eos_bos: int
|
||||
cross_attention_control_args: Optional[Arguments] = None
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return self.cross_attention_control_args is not None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
embeds: torch.Tensor
|
||||
# TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
|
||||
# should only be stored in one place.
|
||||
extra_conditioning: Optional[ExtraConditioningInfo]
|
||||
# weight: float
|
||||
# mode: ConditioningAlgo
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
warmup: float
|
||||
h_symmetry_time_pct: Optional[float]
|
||||
v_symmetry_time_pct: Optional[float]
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterConditioningInfo:
|
||||
cond_image_prompt_embeds: torch.Tensor
|
||||
"""IP-Adapter image encoder conditioning embeddings.
|
||||
Shape: (batch_size, num_tokens, encoding_dim).
|
||||
"""
|
||||
uncond_image_prompt_embeds: torch.Tensor
|
||||
"""IP-Adapter image encoding embeddings to use for unconditional generation.
|
||||
Shape: (batch_size, num_tokens, encoding_dim).
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: BasicConditioningInfo
|
||||
text_embeddings: BasicConditioningInfo
|
||||
guidance_scale: Union[float, List[float]]
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
"""
|
||||
extra: Optional[ExtraConditioningInfo] = None
|
||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||
"""
|
||||
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
|
||||
"""
|
||||
postprocessing_settings: Optional[PostprocessingSettings] = None
|
||||
|
||||
ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.text_embeddings.dtype
|
||||
|
||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
||||
scheduler_args = dict(self.scheduler_args)
|
||||
step_method = inspect.signature(scheduler.step)
|
||||
for name, value in kwargs.items():
|
||||
try:
|
||||
step_method.bind_partial(**{name: value})
|
||||
except TypeError:
|
||||
# FIXME: don't silently discard arguments
|
||||
pass # debug("%s does not accept argument named %r", scheduler, name)
|
||||
else:
|
||||
scheduler_args[name] = value
|
||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
|
||||
@@ -376,11 +376,11 @@ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[
|
||||
# non-fatal error but .swap() won't work.
|
||||
logger.error(
|
||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
||||
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
||||
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
|
||||
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
|
||||
+ "work properly until it is fixed."
|
||||
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
|
||||
"failed or some assumption has changed about the structure of the model itself. Please fix the "
|
||||
f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
|
||||
"inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
|
||||
"attention map display will not work properly until it is fixed."
|
||||
)
|
||||
return attention_module_tuples
|
||||
|
||||
@@ -577,6 +577,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
attention_mask=None,
|
||||
# kwargs
|
||||
swap_cross_attn_context: SwapCrossAttnContext = None,
|
||||
**kwargs,
|
||||
):
|
||||
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -10,9 +9,14 @@ from diffusers import UNet2DConditionModel
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ConditioningData,
|
||||
ExtraConditioningInfo,
|
||||
PostprocessingSettings,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
||||
from .cross_attention_control import (
|
||||
Arguments,
|
||||
Context,
|
||||
CrossAttentionType,
|
||||
SwapCrossAttnContext,
|
||||
@@ -31,37 +35,6 @@ ModelForwardCallback: TypeAlias = Union[
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
embeds: torch.Tensor
|
||||
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
|
||||
# weight: float
|
||||
# mode: ConditioningAlgo
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
warmup: float
|
||||
h_symmetry_time_pct: Optional[float]
|
||||
v_symmetry_time_pct: Optional[float]
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
"""
|
||||
The aim of this component is to provide a single place for code that can be applied identically to
|
||||
@@ -75,15 +48,6 @@ class InvokeAIDiffuserComponent:
|
||||
debug_thresholding = False
|
||||
sequential_guidance = False
|
||||
|
||||
@dataclass
|
||||
class ExtraConditioningInfo:
|
||||
tokens_count_including_eos_bos: int
|
||||
cross_attention_control_args: Optional[Arguments] = None
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return self.cross_attention_control_args is not None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
@@ -103,30 +67,26 @@ class InvokeAIDiffuserComponent:
|
||||
@contextmanager
|
||||
def custom_attention_context(
|
||||
self,
|
||||
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
||||
unet: UNet2DConditionModel,
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int,
|
||||
):
|
||||
old_attn_processors = None
|
||||
if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
|
||||
old_attn_processors = unet.attn_processors
|
||||
# Load lora conditions into the model
|
||||
if extra_conditioning_info.wants_cross_attention_control:
|
||||
self.cross_attention_control_context = Context(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
self.cross_attention_control_context,
|
||||
)
|
||||
old_attn_processors = unet.attn_processors
|
||||
|
||||
try:
|
||||
self.cross_attention_control_context = Context(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
self.cross_attention_control_context,
|
||||
)
|
||||
|
||||
yield None
|
||||
finally:
|
||||
self.cross_attention_control_context = None
|
||||
if old_attn_processors is not None:
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
@@ -376,11 +336,24 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
|
||||
# fast batched path
|
||||
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
|
||||
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
||||
the cost of higher memory usage.
|
||||
"""
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": torch.cat(
|
||||
[
|
||||
conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
|
||||
conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
added_cond_kwargs = None
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
added_cond_kwargs = {
|
||||
@@ -408,6 +381,7 @@ class InvokeAIDiffuserComponent:
|
||||
x_twice,
|
||||
sigma_twice,
|
||||
both_conditionings,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
@@ -419,9 +393,12 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
conditioning_data,
|
||||
conditioning_data: ConditioningData,
|
||||
**kwargs,
|
||||
):
|
||||
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
|
||||
slower execution speed.
|
||||
"""
|
||||
# low-memory sequential path
|
||||
uncond_down_block, cond_down_block = None, None
|
||||
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
|
||||
@@ -437,6 +414,13 @@ class InvokeAIDiffuserComponent:
|
||||
if mid_block_additional_residual is not None:
|
||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||
|
||||
# Run unconditional UNet denoising.
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
|
||||
}
|
||||
|
||||
added_cond_kwargs = None
|
||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
||||
if is_sdxl:
|
||||
@@ -449,12 +433,21 @@ class InvokeAIDiffuserComponent:
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data.unconditioned_embeddings.embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=uncond_down_block,
|
||||
mid_block_additional_residual=uncond_mid_block,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Run conditional UNet denoising.
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
|
||||
}
|
||||
|
||||
added_cond_kwargs = None
|
||||
if is_sdxl:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||
@@ -465,6 +458,7 @@ class InvokeAIDiffuserComponent:
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data.text_embeddings.embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=cond_down_block,
|
||||
mid_block_additional_residual=cond_mid_block,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
|
||||
568
invokeai/backend/util/db_maintenance.py
Normal file
568
invokeai/backend/util/db_maintenance.py
Normal file
@@ -0,0 +1,568 @@
|
||||
# pylint: disable=line-too-long
|
||||
# pylint: disable=broad-exception-caught
|
||||
# pylint: disable=missing-function-docstring
|
||||
"""Script to peform db maintenance and outputs directory management."""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import enum
|
||||
import glob
|
||||
import locale
|
||||
import os
|
||||
import shutil
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import PIL
|
||||
import PIL.ImageOps
|
||||
import PIL.PngImagePlugin
|
||||
import yaml
|
||||
|
||||
|
||||
class ConfigMapper:
|
||||
"""Configuration loader."""
|
||||
|
||||
def __init__(self): # noqa D107
|
||||
pass
|
||||
|
||||
TIMESTAMP_STRING = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
|
||||
|
||||
INVOKE_DIRNAME = "invokeai"
|
||||
YAML_FILENAME = "invokeai.yaml"
|
||||
DATABASE_FILENAME = "invokeai.db"
|
||||
|
||||
database_path = None
|
||||
database_backup_dir = None
|
||||
outputs_path = None
|
||||
archive_path = None
|
||||
thumbnails_path = None
|
||||
thumbnails_archive_path = None
|
||||
|
||||
def load(self):
|
||||
"""Read paths from yaml config and validate."""
|
||||
root = "."
|
||||
|
||||
if not self.__load_from_root_config(os.path.abspath(root)):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def __load_from_root_config(self, invoke_root):
|
||||
"""Validate a yaml path exists, confirm the user wants to use it and load config."""
|
||||
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
|
||||
if os.path.exists(yaml_path):
|
||||
db_dir, outdir = self.__load_paths_from_yaml_file(yaml_path)
|
||||
|
||||
if db_dir is None or outdir is None:
|
||||
print("The invokeai.yaml file was found but is missing the db_dir and/or outdir setting!")
|
||||
return False
|
||||
|
||||
if os.path.isabs(db_dir):
|
||||
self.database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
|
||||
else:
|
||||
self.database_path = os.path.join(invoke_root, db_dir, self.DATABASE_FILENAME)
|
||||
|
||||
self.database_backup_dir = os.path.join(os.path.dirname(self.database_path), "backup")
|
||||
|
||||
if os.path.isabs(outdir):
|
||||
self.outputs_path = os.path.join(outdir, "images")
|
||||
self.archive_path = os.path.join(outdir, "images-archive")
|
||||
else:
|
||||
self.outputs_path = os.path.join(invoke_root, outdir, "images")
|
||||
self.archive_path = os.path.join(invoke_root, outdir, "images-archive")
|
||||
|
||||
self.thumbnails_path = os.path.join(self.outputs_path, "thumbnails")
|
||||
self.thumbnails_archive_path = os.path.join(self.archive_path, "thumbnails")
|
||||
|
||||
db_exists = os.path.exists(self.database_path)
|
||||
outdir_exists = os.path.exists(self.outputs_path)
|
||||
|
||||
text = f"Found {self.YAML_FILENAME} file at {yaml_path}:"
|
||||
text += f"\n Database : {self.database_path} - {'Exists!' if db_exists else 'Not Found!'}"
|
||||
text += f"\n Outputs : {self.outputs_path}- {'Exists!' if outdir_exists else 'Not Found!'}"
|
||||
print(text)
|
||||
|
||||
if db_exists and outdir_exists:
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
"\nOne or more paths specified in invoke.yaml do not exist. Please inspect/correct the configuration and ensure the script is run in the developer console mode (option 8) from an Invoke AI root directory."
|
||||
)
|
||||
return False
|
||||
else:
|
||||
print(
|
||||
f"Auto-discovery of configuration failed! Could not find ({yaml_path})!\n\nPlease ensure the script is run in the developer console mode (option 8) from an Invoke AI root directory."
|
||||
)
|
||||
return False
|
||||
|
||||
def __load_paths_from_yaml_file(self, yaml_path):
|
||||
"""Load an Invoke AI yaml file and get the database and outputs paths."""
|
||||
try:
|
||||
with open(yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||
yamlinfo = yaml.safe_load(file)
|
||||
db_dir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("db_dir", None)
|
||||
outdir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("outdir", None)
|
||||
return db_dir, outdir
|
||||
except Exception:
|
||||
print(f"Failed to load paths from yaml file! {yaml_path}!")
|
||||
return None, None
|
||||
|
||||
|
||||
class MaintenanceStats:
|
||||
"""DTO for tracking work progress."""
|
||||
|
||||
def __init__(self): # noqa D107
|
||||
pass
|
||||
|
||||
time_start = datetime.datetime.utcnow()
|
||||
count_orphaned_db_entries_cleaned = 0
|
||||
count_orphaned_disk_files_cleaned = 0
|
||||
count_orphaned_thumbnails_cleaned = 0
|
||||
count_thumbnails_regenerated = 0
|
||||
count_errors = 0
|
||||
|
||||
@staticmethod
|
||||
def get_elapsed_time_string():
|
||||
"""Get a friendly time string for the time elapsed since processing start."""
|
||||
time_now = datetime.datetime.utcnow()
|
||||
total_seconds = (time_now - MaintenanceStats.time_start).total_seconds()
|
||||
hours = int((total_seconds) / 3600)
|
||||
minutes = int(((total_seconds) % 3600) / 60)
|
||||
seconds = total_seconds % 60
|
||||
out_str = f"{hours} hour(s) -" if hours > 0 else ""
|
||||
out_str += f"{minutes} minute(s) -" if minutes > 0 else ""
|
||||
out_str += f"{seconds:.2f} second(s)"
|
||||
return out_str
|
||||
|
||||
|
||||
class DatabaseMapper:
|
||||
"""Class to abstract database functionality."""
|
||||
|
||||
def __init__(self, database_path, database_backup_dir): # noqa D107
|
||||
self.database_path = database_path
|
||||
self.database_backup_dir = database_backup_dir
|
||||
self.connection = None
|
||||
self.cursor = None
|
||||
|
||||
def backup(self, timestamp_string):
|
||||
"""Take a backup of the database."""
|
||||
if not os.path.exists(self.database_backup_dir):
|
||||
print(f"Database backup directory {self.database_backup_dir} does not exist -> creating...", end="")
|
||||
os.makedirs(self.database_backup_dir)
|
||||
print("Done!")
|
||||
database_backup_path = os.path.join(self.database_backup_dir, f"backup-{timestamp_string}-invokeai.db")
|
||||
print(f"Making DB Backup at {database_backup_path}...", end="")
|
||||
shutil.copy2(self.database_path, database_backup_path)
|
||||
print("Done!")
|
||||
|
||||
def connect(self):
|
||||
"""Open connection to the database."""
|
||||
self.connection = sqlite3.connect(self.database_path)
|
||||
self.cursor = self.connection.cursor()
|
||||
|
||||
def get_all_image_files(self):
|
||||
"""Get the full list of image file names from the database."""
|
||||
sql_get_image_by_name = "SELECT image_name FROM images"
|
||||
self.cursor.execute(sql_get_image_by_name)
|
||||
rows = self.cursor.fetchall()
|
||||
db_files = []
|
||||
for row in rows:
|
||||
db_files.append(row[0])
|
||||
return db_files
|
||||
|
||||
def remove_image_file_record(self, filename: str):
|
||||
"""Remove an image file reference from the database by filename."""
|
||||
sanitized_filename = str.replace(filename, "'", "''") # prevent injection
|
||||
sql_command = f"DELETE FROM images WHERE image_name='{sanitized_filename}'"
|
||||
self.cursor.execute(sql_command)
|
||||
self.connection.commit()
|
||||
|
||||
def does_image_exist(self, image_filename):
|
||||
"""Check database if a image name already exists and return a boolean."""
|
||||
sanitized_filename = str.replace(image_filename, "'", "''") # prevent injection
|
||||
sql_get_image_by_name = f"SELECT image_name FROM images WHERE image_name='{sanitized_filename}'"
|
||||
self.cursor.execute(sql_get_image_by_name)
|
||||
rows = self.cursor.fetchall()
|
||||
return True if len(rows) > 0 else False
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect from the db, cleaning up connections and cursors."""
|
||||
if self.cursor is not None:
|
||||
self.cursor.close()
|
||||
if self.connection is not None:
|
||||
self.connection.close()
|
||||
|
||||
|
||||
class PhysicalFileMapper:
|
||||
"""Containing class for script functionality."""
|
||||
|
||||
def __init__(self, outputs_path, thumbnails_path, archive_path, thumbnails_archive_path): # noqa D107
|
||||
self.outputs_path = outputs_path
|
||||
self.archive_path = archive_path
|
||||
self.thumbnails_path = thumbnails_path
|
||||
self.thumbnails_archive_path = thumbnails_archive_path
|
||||
|
||||
def create_archive_directories(self):
|
||||
"""Create the directory for archiving orphaned image files."""
|
||||
if not os.path.exists(self.archive_path):
|
||||
print(f"Image archive directory ({self.archive_path}) does not exist -> creating...", end="")
|
||||
os.makedirs(self.archive_path)
|
||||
print("Created!")
|
||||
if not os.path.exists(self.thumbnails_archive_path):
|
||||
print(
|
||||
f"Image thumbnails archive directory ({self.thumbnails_archive_path}) does not exist -> creating...",
|
||||
end="",
|
||||
)
|
||||
os.makedirs(self.thumbnails_archive_path)
|
||||
print("Created!")
|
||||
|
||||
def get_image_path_for_image_name(self, image_filename): # noqa D102
|
||||
return os.path.join(self.outputs_path, image_filename)
|
||||
|
||||
def image_file_exists(self, image_filename): # noqa D102
|
||||
return os.path.exists(self.get_image_path_for_image_name(image_filename))
|
||||
|
||||
def get_thumbnail_path_for_image(self, image_filename): # noqa D102
|
||||
return os.path.join(self.thumbnails_path, os.path.splitext(image_filename)[0]) + ".webp"
|
||||
|
||||
def get_image_name_from_thumbnail_path(self, thumbnail_path): # noqa D102
|
||||
return os.path.splitext(os.path.basename(thumbnail_path))[0] + ".png"
|
||||
|
||||
def thumbnail_exists_for_filename(self, image_filename): # noqa D102
|
||||
return os.path.exists(self.get_thumbnail_path_for_image(image_filename))
|
||||
|
||||
def archive_image(self, image_filename): # noqa D102
|
||||
if self.image_file_exists(image_filename):
|
||||
image_path = self.get_image_path_for_image_name(image_filename)
|
||||
shutil.move(image_path, self.archive_path)
|
||||
|
||||
def archive_thumbnail_by_image_filename(self, image_filename): # noqa D102
|
||||
if self.thumbnail_exists_for_filename(image_filename):
|
||||
thumbnail_path = self.get_thumbnail_path_for_image(image_filename)
|
||||
shutil.move(thumbnail_path, self.thumbnails_archive_path)
|
||||
|
||||
def get_all_png_filenames_in_directory(self, directory_path): # noqa D102
|
||||
filepaths = glob.glob(directory_path + "/*.png", recursive=False)
|
||||
filenames = []
|
||||
for filepath in filepaths:
|
||||
filenames.append(os.path.basename(filepath))
|
||||
return filenames
|
||||
|
||||
def get_all_thumbnails_with_full_path(self, thumbnails_directory): # noqa D102
|
||||
return glob.glob(thumbnails_directory + "/*.webp", recursive=False)
|
||||
|
||||
def generate_thumbnail_for_image_name(self, image_filename): # noqa D102
|
||||
# create thumbnail
|
||||
file_path = self.get_image_path_for_image_name(image_filename)
|
||||
thumb_path = self.get_thumbnail_path_for_image(image_filename)
|
||||
thumb_size = 256, 256
|
||||
with PIL.Image.open(file_path) as source_image:
|
||||
source_image.thumbnail(thumb_size)
|
||||
source_image.save(thumb_path, "webp")
|
||||
|
||||
|
||||
class MaintenanceOperation(str, enum.Enum):
|
||||
"""Enum class for operations."""
|
||||
|
||||
Ask = "ask"
|
||||
CleanOrphanedDbEntries = "clean"
|
||||
CleanOrphanedDiskFiles = "archive"
|
||||
ReGenerateThumbnails = "thumbnails"
|
||||
All = "all"
|
||||
|
||||
|
||||
class InvokeAIDatabaseMaintenanceApp:
|
||||
"""Main processor class for the application."""
|
||||
|
||||
_operation: MaintenanceOperation
|
||||
_headless: bool = False
|
||||
__stats: MaintenanceStats = MaintenanceStats()
|
||||
|
||||
def __init__(self, operation: MaintenanceOperation = MaintenanceOperation.Ask):
|
||||
"""Initialize maintenance app."""
|
||||
self._operation = MaintenanceOperation(operation)
|
||||
self._headless = operation != MaintenanceOperation.Ask
|
||||
|
||||
def ask_for_operation(self) -> MaintenanceOperation:
|
||||
"""Ask user to choose the operation to perform."""
|
||||
while True:
|
||||
print()
|
||||
print("It is recommennded to run these operations as ordered below to avoid additional")
|
||||
print("work being performed that will be discarded in a subsequent step.")
|
||||
print()
|
||||
print("Select maintenance operation:")
|
||||
print()
|
||||
print("1) Clean Orphaned Database Image Entries")
|
||||
print(" Cleans entries in the database where the matching file was removed from")
|
||||
print(" the outputs directory.")
|
||||
print("2) Archive Orphaned Image Files")
|
||||
print(" Files found in the outputs directory without an entry in the database are")
|
||||
print(" moved to an archive directory.")
|
||||
print("3) Re-Generate Missing Thumbnail Files")
|
||||
print(" For files found in the outputs directory, re-generate a thumbnail if it")
|
||||
print(" not found in the thumbnails directory.")
|
||||
print()
|
||||
print("(CTRL-C to quit)")
|
||||
|
||||
try:
|
||||
input_option = int(input("Specify desired operation number (1-3): "))
|
||||
|
||||
operations = [
|
||||
MaintenanceOperation.CleanOrphanedDbEntries,
|
||||
MaintenanceOperation.CleanOrphanedDiskFiles,
|
||||
MaintenanceOperation.ReGenerateThumbnails,
|
||||
]
|
||||
return operations[input_option - 1]
|
||||
except (IndexError, ValueError):
|
||||
print("\nInvalid selection!")
|
||||
|
||||
def ask_to_continue(self) -> bool:
|
||||
"""Ask user whether they want to continue with the operation."""
|
||||
while True:
|
||||
input_choice = input("Do you wish to continue? (Y or N)? ")
|
||||
if str.lower(input_choice) == "y":
|
||||
return True
|
||||
if str.lower(input_choice) == "n":
|
||||
return False
|
||||
|
||||
def clean_orphaned_db_entries(
|
||||
self, config: ConfigMapper, file_mapper: PhysicalFileMapper, db_mapper: DatabaseMapper
|
||||
):
|
||||
"""Clean dangling database entries that no longer point to a file in outputs."""
|
||||
if self._headless:
|
||||
print(f"Removing database references to images that no longer exist in {config.outputs_path}...")
|
||||
else:
|
||||
print()
|
||||
print("===============================================================================")
|
||||
print("= Clean Orphaned Database Entries")
|
||||
print()
|
||||
print("Perform this operation if you have removed files from the outputs/images")
|
||||
print("directory but the database was never updated. You may see this as empty imaages")
|
||||
print("in the app gallery, or images that only show an enlarged version of the")
|
||||
print("thumbnail.")
|
||||
print()
|
||||
print(f"Database File Path : {config.database_path}")
|
||||
print(f"Database backup will be taken at : {config.database_backup_dir}")
|
||||
print(f"Outputs/Images Directory : {config.outputs_path}")
|
||||
print(f"Outputs/Images Archive Directory : {config.archive_path}")
|
||||
|
||||
print("\nNotes about this operation:")
|
||||
print("- This operation will find database image file entries that do not exist in the")
|
||||
print(" outputs/images dir and remove those entries from the database.")
|
||||
print("- This operation will target all image types including intermediate files.")
|
||||
print("- If a thumbnail still exists in outputs/images/thumbnails matching the")
|
||||
print(" orphaned entry, it will be moved to the archive directory.")
|
||||
print()
|
||||
|
||||
if not self.ask_to_continue():
|
||||
raise KeyboardInterrupt
|
||||
|
||||
file_mapper.create_archive_directories()
|
||||
db_mapper.backup(config.TIMESTAMP_STRING)
|
||||
db_mapper.connect()
|
||||
db_files = db_mapper.get_all_image_files()
|
||||
for db_file in db_files:
|
||||
try:
|
||||
if not file_mapper.image_file_exists(db_file):
|
||||
print(f"Found orphaned image db entry {db_file}. Cleaning ...", end="")
|
||||
db_mapper.remove_image_file_record(db_file)
|
||||
print("Cleaned!")
|
||||
if file_mapper.thumbnail_exists_for_filename(db_file):
|
||||
print("A thumbnail was found, archiving ...", end="")
|
||||
file_mapper.archive_thumbnail_by_image_filename(db_file)
|
||||
print("Archived!")
|
||||
self.__stats.count_orphaned_db_entries_cleaned += 1
|
||||
except Exception as ex:
|
||||
print("An error occurred cleaning db entry, error was:")
|
||||
print(ex)
|
||||
self.__stats.count_errors += 1
|
||||
|
||||
def clean_orphaned_disk_files(
|
||||
self, config: ConfigMapper, file_mapper: PhysicalFileMapper, db_mapper: DatabaseMapper
|
||||
):
|
||||
"""Archive image files that no longer have entries in the database."""
|
||||
if self._headless:
|
||||
print(f"Archiving orphaned image files to {config.archive_path}...")
|
||||
else:
|
||||
print()
|
||||
print("===============================================================================")
|
||||
print("= Clean Orphaned Disk Files")
|
||||
print()
|
||||
print("Perform this operation if you have files that were copied into the outputs")
|
||||
print("directory which are not referenced by the database. This can happen if you")
|
||||
print("upgraded to a version with a fresh database, but re-used the outputs directory")
|
||||
print("and now new images are mixed with the files not in the db. The script will")
|
||||
print("archive these files so you can choose to delete them or re-import using the")
|
||||
print("official import script.")
|
||||
print()
|
||||
print(f"Database File Path : {config.database_path}")
|
||||
print(f"Database backup will be taken at : {config.database_backup_dir}")
|
||||
print(f"Outputs/Images Directory : {config.outputs_path}")
|
||||
print(f"Outputs/Images Archive Directory : {config.archive_path}")
|
||||
|
||||
print("\nNotes about this operation:")
|
||||
print("- This operation will find image files not referenced by the database and move to an")
|
||||
print(" archive directory.")
|
||||
print("- This operation will target all image types including intermediate references.")
|
||||
print("- The matching thumbnail will also be archived.")
|
||||
print("- Any remaining orphaned thumbnails will also be archived.")
|
||||
|
||||
if not self.ask_to_continue():
|
||||
raise KeyboardInterrupt
|
||||
|
||||
print()
|
||||
|
||||
file_mapper.create_archive_directories()
|
||||
db_mapper.backup(config.TIMESTAMP_STRING)
|
||||
db_mapper.connect()
|
||||
phys_files = file_mapper.get_all_png_filenames_in_directory(config.outputs_path)
|
||||
for phys_file in phys_files:
|
||||
try:
|
||||
if not db_mapper.does_image_exist(phys_file):
|
||||
print(f"Found orphaned file {phys_file}, archiving...", end="")
|
||||
file_mapper.archive_image(phys_file)
|
||||
print("Archived!")
|
||||
if file_mapper.thumbnail_exists_for_filename(phys_file):
|
||||
print("Related thumbnail exists, archiving...", end="")
|
||||
file_mapper.archive_thumbnail_by_image_filename(phys_file)
|
||||
print("Archived!")
|
||||
else:
|
||||
print("No matching thumbnail existed to be cleaned.")
|
||||
self.__stats.count_orphaned_disk_files_cleaned += 1
|
||||
except Exception as ex:
|
||||
print("Error found trying to archive file or thumbnail, error was:")
|
||||
print(ex)
|
||||
self.__stats.count_errors += 1
|
||||
|
||||
thumb_filepaths = file_mapper.get_all_thumbnails_with_full_path(config.thumbnails_path)
|
||||
# archive any remaining orphaned thumbnails
|
||||
for thumb_filepath in thumb_filepaths:
|
||||
try:
|
||||
thumb_src_image_name = file_mapper.get_image_name_from_thumbnail_path(thumb_filepath)
|
||||
if not file_mapper.image_file_exists(thumb_src_image_name):
|
||||
print(f"Found orphaned thumbnail {thumb_filepath}, archiving...", end="")
|
||||
file_mapper.archive_thumbnail_by_image_filename(thumb_src_image_name)
|
||||
print("Archived!")
|
||||
self.__stats.count_orphaned_thumbnails_cleaned += 1
|
||||
except Exception as ex:
|
||||
print("Error found trying to archive thumbnail, error was:")
|
||||
print(ex)
|
||||
self.__stats.count_errors += 1
|
||||
|
||||
def regenerate_thumbnails(self, config: ConfigMapper, file_mapper: PhysicalFileMapper, *args):
|
||||
"""Create missing thumbnails for any valid general images both in the db and on disk."""
|
||||
if self._headless:
|
||||
print("Regenerating missing image thumbnails...")
|
||||
else:
|
||||
print()
|
||||
print("===============================================================================")
|
||||
print("= Regenerate Thumbnails")
|
||||
print()
|
||||
print("This operation will find files that have no matching thumbnail on disk")
|
||||
print("and regenerate those thumbnail files.")
|
||||
print("NOTE: It is STRONGLY recommended that the user first clean/archive orphaned")
|
||||
print(" disk files from the previous menu to avoid wasting time regenerating")
|
||||
print(" thumbnails for orphaned files.")
|
||||
|
||||
print()
|
||||
print(f"Outputs/Images Directory : {config.outputs_path}")
|
||||
print(f"Outputs/Images Directory : {config.thumbnails_path}")
|
||||
|
||||
print("\nNotes about this operation:")
|
||||
print("- This operation will find image files both referenced in the db and on disk")
|
||||
print(" that do not have a matching thumbnail on disk and re-generate the thumbnail")
|
||||
print(" file.")
|
||||
|
||||
if not self.ask_to_continue():
|
||||
raise KeyboardInterrupt
|
||||
|
||||
print()
|
||||
|
||||
phys_files = file_mapper.get_all_png_filenames_in_directory(config.outputs_path)
|
||||
for phys_file in phys_files:
|
||||
try:
|
||||
if not file_mapper.thumbnail_exists_for_filename(phys_file):
|
||||
print(f"Found file without thumbnail {phys_file}...Regenerating Thumbnail...", end="")
|
||||
file_mapper.generate_thumbnail_for_image_name(phys_file)
|
||||
print("Done!")
|
||||
self.__stats.count_thumbnails_regenerated += 1
|
||||
except Exception as ex:
|
||||
print("Error found trying to regenerate thumbnail, error was:")
|
||||
print(ex)
|
||||
self.__stats.count_errors += 1
|
||||
|
||||
def main(self): # noqa D107
|
||||
print("\n===============================================================================")
|
||||
print("Database and outputs Maintenance for Invoke AI 3.0.0 +")
|
||||
print("===============================================================================\n")
|
||||
|
||||
config_mapper = ConfigMapper()
|
||||
if not config_mapper.load():
|
||||
print("\nInvalid configuration...exiting.\n")
|
||||
return
|
||||
|
||||
file_mapper = PhysicalFileMapper(
|
||||
config_mapper.outputs_path,
|
||||
config_mapper.thumbnails_path,
|
||||
config_mapper.archive_path,
|
||||
config_mapper.thumbnails_archive_path,
|
||||
)
|
||||
db_mapper = DatabaseMapper(config_mapper.database_path, config_mapper.database_backup_dir)
|
||||
|
||||
op = self._operation
|
||||
operations_to_perform = []
|
||||
|
||||
if op == MaintenanceOperation.Ask:
|
||||
op = self.ask_for_operation()
|
||||
|
||||
if op in [MaintenanceOperation.CleanOrphanedDbEntries, MaintenanceOperation.All]:
|
||||
operations_to_perform.append(self.clean_orphaned_db_entries)
|
||||
if op in [MaintenanceOperation.CleanOrphanedDiskFiles, MaintenanceOperation.All]:
|
||||
operations_to_perform.append(self.clean_orphaned_disk_files)
|
||||
if op in [MaintenanceOperation.ReGenerateThumbnails, MaintenanceOperation.All]:
|
||||
operations_to_perform.append(self.regenerate_thumbnails)
|
||||
|
||||
for operation in operations_to_perform:
|
||||
operation(config_mapper, file_mapper, db_mapper)
|
||||
|
||||
print("\n===============================================================================")
|
||||
print(f"= Maintenance Complete - Elapsed Time: {MaintenanceStats.get_elapsed_time_string()}")
|
||||
print()
|
||||
print(f"Orphaned db entries cleaned : {self.__stats.count_orphaned_db_entries_cleaned}")
|
||||
print(f"Orphaned disk files archived : {self.__stats.count_orphaned_disk_files_cleaned}")
|
||||
print(f"Orphaned thumbnail files archived : {self.__stats.count_orphaned_thumbnails_cleaned}")
|
||||
print(f"Thumbnails regenerated : {self.__stats.count_thumbnails_regenerated}")
|
||||
print(f"Errors during operation : {self.__stats.count_errors}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def main(): # noqa D107
|
||||
parser = argparse.ArgumentParser(
|
||||
description="InvokeAI image database maintenance utility",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""Operations:
|
||||
ask Choose operation from a menu [default]
|
||||
all Run all maintenance operations
|
||||
clean Clean database of dangling entries
|
||||
archive Archive orphaned image files
|
||||
thumbnails Regenerate missing image thumbnails
|
||||
""",
|
||||
)
|
||||
parser.add_argument("--root", default=".", type=Path, help="InvokeAI root directory")
|
||||
parser.add_argument(
|
||||
"--operation", default="ask", choices=[x.value for x in MaintenanceOperation], help="Operation to perform."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
os.chdir(args.root)
|
||||
app = InvokeAIDatabaseMaintenanceApp(args.operation)
|
||||
app.main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nUser cancelled execution.")
|
||||
except FileNotFoundError:
|
||||
print(f"Invalid root directory '{args.root}'.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -14,7 +14,6 @@ import os
|
||||
import re
|
||||
import shutil
|
||||
import sqlite3
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import PIL
|
||||
@@ -27,6 +26,7 @@ from prompt_toolkit.key_binding import KeyBindings
|
||||
from prompt_toolkit.shortcuts import message_dialog
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
|
||||
@@ -421,7 +421,7 @@ VALUES ('{filename}', 'internal', 'general', {width}, {height}, null, null, '{me
|
||||
return rows[0][0]
|
||||
else:
|
||||
board_date_string = datetime.datetime.utcnow().date().isoformat()
|
||||
new_board_id = str(uuid.uuid4())
|
||||
new_board_id = uuid_string()
|
||||
sql_insert_board = f"INSERT INTO boards (board_id, board_name, created_at, updated_at) VALUES ('{new_board_id}', '{board_name}', '{board_date_string}', '{board_date_string}')"
|
||||
self.cursor.execute(sql_insert_board)
|
||||
self.connection.commit()
|
||||
|
||||
169
invokeai/frontend/web/dist/assets/App-61ae6f9e.js
vendored
Normal file
169
invokeai/frontend/web/dist/assets/App-61ae6f9e.js
vendored
Normal file
File diff suppressed because one or more lines are too long
169
invokeai/frontend/web/dist/assets/App-d1567775.js
vendored
169
invokeai/frontend/web/dist/assets/App-d1567775.js
vendored
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
310
invokeai/frontend/web/dist/assets/ThemeLocaleProvider-9f721f69.js
vendored
Normal file
310
invokeai/frontend/web/dist/assets/ThemeLocaleProvider-9f721f69.js
vendored
Normal file
File diff suppressed because one or more lines are too long
128
invokeai/frontend/web/dist/assets/index-eac60e23.js
vendored
Normal file
128
invokeai/frontend/web/dist/assets/index-eac60e23.js
vendored
Normal file
File diff suppressed because one or more lines are too long
128
invokeai/frontend/web/dist/assets/index-f83c2c5c.js
vendored
128
invokeai/frontend/web/dist/assets/index-f83c2c5c.js
vendored
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
invokeai/frontend/web/dist/assets/menu-f40e1624.js
vendored
Normal file
1
invokeai/frontend/web/dist/assets/menu-f40e1624.js
vendored
Normal file
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-f83c2c5c.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-eac60e23.js"></script>
|
||||
</head>
|
||||
|
||||
<body dir="ltr">
|
||||
|
||||
1822
invokeai/frontend/web/dist/locales/en.json
vendored
1822
invokeai/frontend/web/dist/locales/en.json
vendored
File diff suppressed because it is too large
Load Diff
@@ -13,14 +13,15 @@
|
||||
"reset": "Reset",
|
||||
"rotateClockwise": "Rotate Clockwise",
|
||||
"rotateCounterClockwise": "Rotate Counter-Clockwise",
|
||||
"showGallery": "Show Gallery",
|
||||
"showGalleryPanel": "Show Gallery Panel",
|
||||
"showOptionsPanel": "Show Side Panel",
|
||||
"toggleAutoscroll": "Toggle autoscroll",
|
||||
"toggleLogViewer": "Toggle Log Viewer",
|
||||
"uploadImage": "Upload Image",
|
||||
"useThisParameter": "Use this parameter",
|
||||
"zoomIn": "Zoom In",
|
||||
"zoomOut": "Zoom Out"
|
||||
"zoomOut": "Zoom Out",
|
||||
"loadMore": "Load More"
|
||||
},
|
||||
"boards": {
|
||||
"addBoard": "Add Board",
|
||||
@@ -49,6 +50,7 @@
|
||||
"close": "Close",
|
||||
"communityLabel": "Community",
|
||||
"controlNet": "Controlnet",
|
||||
"ipAdapter": "IP Adapter",
|
||||
"darkMode": "Dark Mode",
|
||||
"discordLabel": "Discord",
|
||||
"dontAskMeAgain": "Don't ask me again",
|
||||
@@ -109,6 +111,7 @@
|
||||
"statusModelChanged": "Model Changed",
|
||||
"statusModelConverted": "Model Converted",
|
||||
"statusPreparing": "Preparing",
|
||||
"statusProcessing": "Processing",
|
||||
"statusProcessingCanceled": "Processing Canceled",
|
||||
"statusProcessingComplete": "Processing Complete",
|
||||
"statusRestoringFaces": "Restoring Faces",
|
||||
@@ -191,13 +194,92 @@
|
||||
"showAdvanced": "Show Advanced",
|
||||
"toggleControlNet": "Toggle this ControlNet",
|
||||
"w": "W",
|
||||
"weight": "Weight"
|
||||
"weight": "Weight",
|
||||
"enableIPAdapter": "Enable IP Adapter",
|
||||
"ipAdapterModel": "Adapter Model",
|
||||
"resetIPAdapterImage": "Reset IP Adapter Image",
|
||||
"ipAdapterImageFallback": "No IP Adapter Image Selected"
|
||||
},
|
||||
"embedding": {
|
||||
"addEmbedding": "Add Embedding",
|
||||
"incompatibleModel": "Incompatible base model:",
|
||||
"noMatchingEmbedding": "No matching Embeddings"
|
||||
},
|
||||
"queue": {
|
||||
"queue": "Queue",
|
||||
"queueFront": "Add to Front of Queue",
|
||||
"queueBack": "Add to Queue",
|
||||
"queueCountPrediction": "Add {{predicted}} to Queue",
|
||||
"queueMaxExceeded": "Max of {{max_queue_size}} exceeded, would skip {{skip}}",
|
||||
"queuedCount": "{{pending}} Pending",
|
||||
"queueTotal": "{{total}} Total",
|
||||
"queueEmpty": "Queue Empty",
|
||||
"enqueueing": "Queueing Batch",
|
||||
"resume": "Resume",
|
||||
"resumeTooltip": "Resume Processor",
|
||||
"resumeSucceeded": "Processor Resumed",
|
||||
"resumeFailed": "Problem Resuming Processor",
|
||||
"pause": "Pause",
|
||||
"pauseTooltip": "Pause Processor",
|
||||
"pauseSucceeded": "Processor Paused",
|
||||
"pauseFailed": "Problem Pausing Processor",
|
||||
"cancel": "Cancel",
|
||||
"cancelTooltip": "Cancel Current Item",
|
||||
"cancelSucceeded": "Item Canceled",
|
||||
"cancelFailed": "Problem Canceling Item",
|
||||
"prune": "Prune",
|
||||
"pruneTooltip": "Prune {{item_count}} Completed Items",
|
||||
"pruneSucceeded": "Pruned {{item_count}} Completed Items from Queue",
|
||||
"pruneFailed": "Problem Pruning Queue",
|
||||
"clear": "Clear",
|
||||
"clearTooltip": "Cancel and Clear All Items",
|
||||
"clearSucceeded": "Queue Cleared",
|
||||
"clearFailed": "Problem Clearing Queue",
|
||||
"cancelBatch": "Cancel Batch",
|
||||
"cancelItem": "Cancel Item",
|
||||
"cancelBatchSucceeded": "Batch Canceled",
|
||||
"cancelBatchFailed": "Problem Canceling Batch",
|
||||
"clearQueueAlertDialog": "Clearing the queue immediately cancels any processing items and clears the queue entirely.",
|
||||
"clearQueueAlertDialog2": "Are you sure you want to clear the queue?",
|
||||
"current": "Current",
|
||||
"next": "Next",
|
||||
"status": "Status",
|
||||
"total": "Total",
|
||||
"pending": "Pending",
|
||||
"in_progress": "In Progress",
|
||||
"completed": "Completed",
|
||||
"failed": "Failed",
|
||||
"canceled": "Canceled",
|
||||
"completedIn": "Completed in",
|
||||
"batch": "Batch",
|
||||
"item": "Item",
|
||||
"session": "Session",
|
||||
"batchValues": "Batch Values",
|
||||
"notReady": "Unable to Queue",
|
||||
"batchQueued": "Batch Queued",
|
||||
"batchQueuedDesc": "Added {{item_count}} sessions to {{direction}} of queue",
|
||||
"front": "front",
|
||||
"back": "back",
|
||||
"batchFailedToQueue": "Failed to Queue Batch",
|
||||
"graphQueued": "Graph queued",
|
||||
"graphFailedToQueue": "Failed to queue graph"
|
||||
},
|
||||
"invocationCache": {
|
||||
"invocationCache": "Invocation Cache",
|
||||
"cacheSize": "Cache Size",
|
||||
"maxCacheSize": "Max Cache Size",
|
||||
"hits": "Cache Hits",
|
||||
"misses": "Cache Misses",
|
||||
"clear": "Clear",
|
||||
"clearSucceeded": "Invocation Cache Cleared",
|
||||
"clearFailed": "Problem Clearing Invocation Cache",
|
||||
"enable": "Enable",
|
||||
"enableSucceeded": "Invocation Cache Enabled",
|
||||
"enableFailed": "Problem Enabling Invocation Cache",
|
||||
"disable": "Disable",
|
||||
"disableSucceeded": "Invocation Cache Disabled",
|
||||
"disableFailed": "Problem Disabling Invocation Cache"
|
||||
},
|
||||
"gallery": {
|
||||
"allImagesLoaded": "All Images Loaded",
|
||||
"assets": "Assets",
|
||||
@@ -636,7 +718,8 @@
|
||||
"collectionItemDescription": "TODO",
|
||||
"colorCodeEdges": "Color-Code Edges",
|
||||
"colorCodeEdgesHelp": "Color-code edges according to their connected fields",
|
||||
"colorCollectionDescription": "A collection of colors.",
|
||||
"colorCollection": "A collection of colors.",
|
||||
"colorCollectionDescription": "TODO",
|
||||
"colorField": "Color",
|
||||
"colorFieldDescription": "A RGBA color.",
|
||||
"colorPolymorphic": "Color Polymorphic",
|
||||
@@ -683,7 +766,8 @@
|
||||
"imageFieldDescription": "Images may be passed between nodes.",
|
||||
"imagePolymorphic": "Image Polymorphic",
|
||||
"imagePolymorphicDescription": "A collection of images.",
|
||||
"inputFields": "Input Feilds",
|
||||
"inputField": "Input Field",
|
||||
"inputFields": "Input Fields",
|
||||
"inputMayOnlyHaveOneConnection": "Input may only have one connection",
|
||||
"inputNode": "Input Node",
|
||||
"integer": "Integer",
|
||||
@@ -701,6 +785,7 @@
|
||||
"latentsPolymorphicDescription": "Latents may be passed between nodes.",
|
||||
"loadingNodes": "Loading Nodes...",
|
||||
"loadWorkflow": "Load Workflow",
|
||||
"noWorkflow": "No Workflow",
|
||||
"loRAModelField": "LoRA",
|
||||
"loRAModelFieldDescription": "TODO",
|
||||
"mainModelField": "Model",
|
||||
@@ -722,14 +807,15 @@
|
||||
"noImageFoundState": "No initial image found in state",
|
||||
"noMatchingNodes": "No matching nodes",
|
||||
"noNodeSelected": "No node selected",
|
||||
"noOpacity": "Node Opacity",
|
||||
"nodeOpacity": "Node Opacity",
|
||||
"noOutputRecorded": "No outputs recorded",
|
||||
"noOutputSchemaName": "No output schema name found in ref object",
|
||||
"notes": "Notes",
|
||||
"notesDescription": "Add notes about your workflow",
|
||||
"oNNXModelField": "ONNX Model",
|
||||
"oNNXModelFieldDescription": "ONNX model field.",
|
||||
"outputFields": "Output Feilds",
|
||||
"outputField": "Output Field",
|
||||
"outputFields": "Output Fields",
|
||||
"outputNode": "Output node",
|
||||
"outputSchemaNotFound": "Output schema not found",
|
||||
"pickOne": "Pick One",
|
||||
@@ -778,6 +864,7 @@
|
||||
"unknownNode": "Unknown Node",
|
||||
"unknownTemplate": "Unknown Template",
|
||||
"unkownInvocation": "Unknown Invocation type",
|
||||
"updateNode": "Update Node",
|
||||
"updateApp": "Update App",
|
||||
"vaeField": "Vae",
|
||||
"vaeFieldDescription": "Vae submodel.",
|
||||
@@ -814,6 +901,7 @@
|
||||
},
|
||||
"cfgScale": "CFG Scale",
|
||||
"clipSkip": "CLIP Skip",
|
||||
"clipSkipWithLayerCount": "CLIP Skip {{layerCount}}",
|
||||
"closeViewer": "Close Viewer",
|
||||
"codeformerFidelity": "Fidelity",
|
||||
"coherenceMode": "Mode",
|
||||
@@ -852,6 +940,7 @@
|
||||
"noInitialImageSelected": "No initial image selected",
|
||||
"noModelForControlNet": "ControlNet {{index}} has no model selected.",
|
||||
"noModelSelected": "No model selected",
|
||||
"noPrompts": "No prompts generated",
|
||||
"noNodesInGraph": "No nodes in graph",
|
||||
"readyToInvoke": "Ready to Invoke",
|
||||
"systemBusy": "System busy",
|
||||
@@ -870,7 +959,12 @@
|
||||
"perlinNoise": "Perlin Noise",
|
||||
"positivePromptPlaceholder": "Positive Prompt",
|
||||
"randomizeSeed": "Randomize Seed",
|
||||
"manualSeed": "Manual Seed",
|
||||
"randomSeed": "Random Seed",
|
||||
"restoreFaces": "Restore Faces",
|
||||
"iterations": "Iterations",
|
||||
"iterationsWithCount_one": "{{count}} Iteration",
|
||||
"iterationsWithCount_other": "{{count}} Iterations",
|
||||
"scale": "Scale",
|
||||
"scaleBeforeProcessing": "Scale Before Processing",
|
||||
"scaledHeight": "Scaled H",
|
||||
@@ -881,13 +975,17 @@
|
||||
"seamlessTiling": "Seamless Tiling",
|
||||
"seamlessXAxis": "X Axis",
|
||||
"seamlessYAxis": "Y Axis",
|
||||
"seamlessX": "Seamless X",
|
||||
"seamlessY": "Seamless Y",
|
||||
"seamlessX&Y": "Seamless X & Y",
|
||||
"seamLowThreshold": "Low",
|
||||
"seed": "Seed",
|
||||
"seedWeights": "Seed Weights",
|
||||
"imageActions": "Image Actions",
|
||||
"sendTo": "Send to",
|
||||
"sendToImg2Img": "Send to Image to Image",
|
||||
"sendToUnifiedCanvas": "Send To Unified Canvas",
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"showOptionsPanel": "Show Side Panel (O or T)",
|
||||
"showPreview": "Show Preview",
|
||||
"shuffle": "Shuffle Seed",
|
||||
"steps": "Steps",
|
||||
@@ -896,11 +994,13 @@
|
||||
"tileSize": "Tile Size",
|
||||
"toggleLoopback": "Toggle Loopback",
|
||||
"type": "Type",
|
||||
"upscale": "Upscale",
|
||||
"upscale": "Upscale (Shift + U)",
|
||||
"upscaleImage": "Upscale Image",
|
||||
"upscaling": "Upscaling",
|
||||
"useAll": "Use All",
|
||||
"useCpuNoise": "Use CPU Noise",
|
||||
"cpuNoise": "CPU Noise",
|
||||
"gpuNoise": "GPU Noise",
|
||||
"useInitImg": "Use Initial Image",
|
||||
"usePrompt": "Use Prompt",
|
||||
"useSeed": "Use Seed",
|
||||
@@ -909,11 +1009,20 @@
|
||||
"vSymmetryStep": "V Symmetry Step",
|
||||
"width": "Width"
|
||||
},
|
||||
"prompt": {
|
||||
"dynamicPrompts": {
|
||||
"combinatorial": "Combinatorial Generation",
|
||||
"dynamicPrompts": "Dynamic Prompts",
|
||||
"enableDynamicPrompts": "Enable Dynamic Prompts",
|
||||
"maxPrompts": "Max Prompts"
|
||||
"maxPrompts": "Max Prompts",
|
||||
"promptsWithCount_one": "{{count}} Prompt",
|
||||
"promptsWithCount_other": "{{count}} Prompts",
|
||||
"seedBehaviour": {
|
||||
"label": "Seed Behaviour",
|
||||
"perIterationLabel": "Seed per Iteration",
|
||||
"perIterationDesc": "Use a different seed for each iteration",
|
||||
"perPromptLabel": "Seed per Prompt",
|
||||
"perPromptDesc": "Use a different seed for each prompt"
|
||||
}
|
||||
},
|
||||
"sdxl": {
|
||||
"cfgScale": "CFG Scale",
|
||||
@@ -1036,6 +1145,7 @@
|
||||
"serverError": "Server Error",
|
||||
"setCanvasInitialImage": "Set as canvas initial image",
|
||||
"setControlImage": "Set as control image",
|
||||
"setIPAdapterImage": "Set as IP Adapter Image",
|
||||
"setInitialImage": "Set as initial image",
|
||||
"setNodeField": "Set as node field",
|
||||
"tempFoldersEmptied": "Temp Folder Emptied",
|
||||
@@ -1060,6 +1170,136 @@
|
||||
"variations": "Try a variation with a value between 0.1 and 1.0 to change the result for a given seed. Interesting variations of the seed are between 0.1 and 0.3."
|
||||
}
|
||||
},
|
||||
"popovers": {
|
||||
"clipSkip": {
|
||||
"heading": "CLIP Skip",
|
||||
"paragraph": "Choose how many layers of the CLIP model to skip. Certain models are better suited to be used with CLIP Skip."
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "Blur",
|
||||
"paragraph": "The blur radius of the mask."
|
||||
},
|
||||
"compositingBlurMethod": {
|
||||
"heading": "Blur Method",
|
||||
"paragraph": "The method of blur applied to the masked area."
|
||||
},
|
||||
"compositingCoherencePass": {
|
||||
"heading": "Coherence Pass",
|
||||
"paragraph": "Composite the Inpainted/Outpainted images."
|
||||
},
|
||||
"compositingCoherenceMode": {
|
||||
"heading": "Mode",
|
||||
"paragraph": "The mode of the Coherence Pass."
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "Steps",
|
||||
"paragraph": "Number of steps in the Coherence Pass. Similar to Denoising Steps."
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "Strength",
|
||||
"paragraph": "Amount of noise added for the Coherence Pass. Similar to Denoising Strength."
|
||||
},
|
||||
"compositingMaskAdjustments": {
|
||||
"heading": "Mask Adjustments",
|
||||
"paragraph": "Adjust the mask."
|
||||
},
|
||||
"controlNetBeginEnd": {
|
||||
"heading": "Begin / End Step Percentage",
|
||||
"paragraph": "Which parts of the denoising process will have the ControlNet applied. ControlNets applied at the start of the process guide composition, and ControlNets applied at the end guide details."
|
||||
},
|
||||
"controlNetControlMode": {
|
||||
"heading": "Control Mode",
|
||||
"paragraph": "Lends more weight to either the prompt or ControlNet."
|
||||
},
|
||||
"controlNetResizeMode": {
|
||||
"heading": "Resize Mode",
|
||||
"paragraph": "How the ControlNet image will be fit to the image generation Ratio"
|
||||
},
|
||||
"controlNetToggle": {
|
||||
"heading": "Enable ControlNet",
|
||||
"paragraph": "ControlNets provide guidance to the generation process, helping create images with controlled composition, structure, or style, depending on the model selected."
|
||||
},
|
||||
"controlNetWeight": {
|
||||
"heading": "Weight",
|
||||
"paragraph": "How strongly the ControlNet will impact the generated image."
|
||||
},
|
||||
"dynamicPromptsToggle": {
|
||||
"heading": "Enable Dynamic Prompts",
|
||||
"paragraph": "Dynamic prompts allow multiple options within a prompt. Dynamic prompts can be used by: {option1|option2|option3}. Combinations of prompts will be randomly generated until the “Images” number has been reached."
|
||||
},
|
||||
"dynamicPromptsCombinatorial": {
|
||||
"heading": "Combinatorial Generation",
|
||||
"paragraph": "Generate an image for every possible combination of Dynamic Prompts until the Max Prompts is reached."
|
||||
},
|
||||
"infillMethod": {
|
||||
"heading": "Infill Method",
|
||||
"paragraph": "Method to infill the selected area."
|
||||
},
|
||||
"lora": {
|
||||
"heading": "LoRA Weight",
|
||||
"paragraph": "Weight of the LoRA. Higher weight will lead to larger impacts on the final image."
|
||||
},
|
||||
"noiseEnable": {
|
||||
"heading": "Enable Noise Settings",
|
||||
"paragraph": "Advanced control over noise generation."
|
||||
},
|
||||
"noiseUseCPU": {
|
||||
"heading": "Use CPU Noise",
|
||||
"paragraph": "Uses the CPU to generate random noise."
|
||||
},
|
||||
"paramCFGScale": {
|
||||
"heading": "CFG Scale",
|
||||
"paragraph": "Controls how much your prompt influences the generation process."
|
||||
},
|
||||
"paramDenoisingStrength": {
|
||||
"heading": "Denoising Strength",
|
||||
"paragraph": "How much noise is added to the input image. 0 will result in an identical image, while 1 will result in a completely new image."
|
||||
},
|
||||
"paramIterations": {
|
||||
"heading": "Iterations",
|
||||
"paragraph": "The number of images to generate. If Dynamic Prompts is enabled, each of the prompts will be generated this many times."
|
||||
},
|
||||
"paramModel": {
|
||||
"heading": "Model",
|
||||
"paragraph": "Model used for the denoising steps. Different models are trained to specialize in producing different aesthetic results and content."
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"heading": "Negative Prompt",
|
||||
"paragraph": "The generation process avoids the concepts in the negative prompt. Use this to exclude qualities or objects from the output. Supports Compel syntax and embeddings."
|
||||
},
|
||||
"paramPositiveConditioning": {
|
||||
"heading": "Positive Prompt",
|
||||
"paragraph": "Guides the generation process. You may use any words or phrases. Supports Compel and Dynamic Prompts syntaxes and embeddings."
|
||||
},
|
||||
"paramRatio": {
|
||||
"heading": "Ratio",
|
||||
"paragraph": "The ratio of the dimensions of the image generated. An image size (in number of pixels) equivalent to 512x512 is recommended for SD1.5 models and a size equivalent to 1024x1024 is recommended for SDXL models."
|
||||
},
|
||||
"paramScheduler": {
|
||||
"heading": "Scheduler",
|
||||
"paragraph": "Scheduler defines how to iteratively add noise to an image or how to update a sample based on a model's output."
|
||||
},
|
||||
"paramSeed": {
|
||||
"heading": "Seed",
|
||||
"paragraph": "Controls the starting noise used for generation. Disable “Random Seed” to produce identical results with the same generation settings."
|
||||
},
|
||||
"paramSteps": {
|
||||
"heading": "Steps",
|
||||
"paragraph": "Number of steps that will be performed in each generation. Higher step counts will typically create better images but will require more generation time."
|
||||
},
|
||||
"paramVAE": {
|
||||
"heading": "VAE",
|
||||
"paragraph": "Model used for translating AI output into the final image."
|
||||
},
|
||||
"paramVAEPrecision": {
|
||||
"heading": "VAE Precision",
|
||||
"paragraph": "The precision used during VAE encoding and decoding. Fp16/Half precision is more efficient, at the expense of minor image variations."
|
||||
},
|
||||
"scaleBeforeProcessing": {
|
||||
"heading": "Scale Before Processing",
|
||||
"paragraph": "Scales the selected area to the size best suited for the model before the image generation process."
|
||||
}
|
||||
},
|
||||
"ui": {
|
||||
"hideProgressImages": "Hide Progress Images",
|
||||
"lockRatio": "Lock Ratio",
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
import { Flex, Spinner, Tooltip } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(systemSelector, (system) => {
|
||||
const { isUploading } = system;
|
||||
|
||||
let tooltip = '';
|
||||
|
||||
if (isUploading) {
|
||||
tooltip = 'Uploading...';
|
||||
}
|
||||
|
||||
return {
|
||||
tooltip,
|
||||
shouldShow: isUploading,
|
||||
};
|
||||
});
|
||||
|
||||
export const AuxiliaryProgressIndicator = () => {
|
||||
const { shouldShow, tooltip } = useAppSelector(selector);
|
||||
|
||||
if (!shouldShow) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
color: 'base.600',
|
||||
}}
|
||||
>
|
||||
<Tooltip label={tooltip} placement="right" hasArrow>
|
||||
<Spinner />
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AuxiliaryProgressIndicator);
|
||||
@@ -1,6 +1,8 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useQueueBack } from 'features/queue/hooks/useQueueBack';
|
||||
import { useQueueFront } from 'features/queue/hooks/useQueueFront';
|
||||
import {
|
||||
ctrlKeyPressed,
|
||||
metaKeyPressed,
|
||||
@@ -33,6 +35,39 @@ const globalHotkeysSelector = createSelector(
|
||||
const GlobalHotkeys: React.FC = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { shift, ctrl, meta } = useAppSelector(globalHotkeysSelector);
|
||||
const {
|
||||
queueBack,
|
||||
isDisabled: isDisabledQueueBack,
|
||||
isLoading: isLoadingQueueBack,
|
||||
} = useQueueBack();
|
||||
|
||||
useHotkeys(
|
||||
['ctrl+enter', 'meta+enter'],
|
||||
queueBack,
|
||||
{
|
||||
enabled: () => !isDisabledQueueBack && !isLoadingQueueBack,
|
||||
preventDefault: true,
|
||||
enableOnFormTags: ['input', 'textarea', 'select'],
|
||||
},
|
||||
[queueBack, isDisabledQueueBack, isLoadingQueueBack]
|
||||
);
|
||||
|
||||
const {
|
||||
queueFront,
|
||||
isDisabled: isDisabledQueueFront,
|
||||
isLoading: isLoadingQueueFront,
|
||||
} = useQueueFront();
|
||||
|
||||
useHotkeys(
|
||||
['ctrl+shift+enter', 'meta+shift+enter'],
|
||||
queueFront,
|
||||
{
|
||||
enabled: () => !isDisabledQueueFront && !isLoadingQueueFront,
|
||||
preventDefault: true,
|
||||
enableOnFormTags: ['input', 'textarea', 'select'],
|
||||
},
|
||||
[queueFront, isDisabledQueueFront, isLoadingQueueFront]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'*',
|
||||
|
||||
@@ -17,6 +17,7 @@ import '../../i18n';
|
||||
import AppDndContext from '../../features/dnd/components/AppDndContext';
|
||||
import { $customStarUI, CustomStarUi } from 'app/store/nanostores/customStarUI';
|
||||
import { $headerComponent } from 'app/store/nanostores/headerComponent';
|
||||
import { $queueId, DEFAULT_QUEUE_ID } from 'features/queue/store/nanoStores';
|
||||
|
||||
const App = lazy(() => import('./App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||
@@ -28,6 +29,7 @@ interface Props extends PropsWithChildren {
|
||||
headerComponent?: ReactNode;
|
||||
middleware?: Middleware[];
|
||||
projectId?: string;
|
||||
queueId?: string;
|
||||
selectedImage?: {
|
||||
imageName: string;
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
@@ -42,6 +44,7 @@ const InvokeAIUI = ({
|
||||
headerComponent,
|
||||
middleware,
|
||||
projectId,
|
||||
queueId,
|
||||
selectedImage,
|
||||
customStarUi,
|
||||
}: Props) => {
|
||||
@@ -61,6 +64,11 @@ const InvokeAIUI = ({
|
||||
$projectId.set(projectId);
|
||||
}
|
||||
|
||||
// configure API client project header
|
||||
if (queueId) {
|
||||
$queueId.set(queueId);
|
||||
}
|
||||
|
||||
// reset dynamically added middlewares
|
||||
resetMiddlewares();
|
||||
|
||||
@@ -81,8 +89,9 @@ const InvokeAIUI = ({
|
||||
$baseUrl.set(undefined);
|
||||
$authToken.set(undefined);
|
||||
$projectId.set(undefined);
|
||||
$queueId.set(DEFAULT_QUEUE_ID);
|
||||
};
|
||||
}, [apiUrl, token, middleware, projectId]);
|
||||
}, [apiUrl, token, middleware, projectId, queueId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (customStarUi) {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useToast } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { toastQueueSelector } from 'features/system/store/systemSelectors';
|
||||
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
|
||||
import { MakeToastArg, makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
@@ -11,7 +10,7 @@ import { memo, useCallback, useEffect } from 'react';
|
||||
*/
|
||||
const Toaster = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const toastQueue = useAppSelector(toastQueueSelector);
|
||||
const toastQueue = useAppSelector((state) => state.system.toastQueue);
|
||||
const toast = useToast();
|
||||
useEffect(() => {
|
||||
toastQueue.forEach((t) => {
|
||||
|
||||
@@ -20,6 +20,7 @@ export type LoggerNamespace =
|
||||
| 'system'
|
||||
| 'socketio'
|
||||
| 'session'
|
||||
| 'queue'
|
||||
| 'dnd';
|
||||
|
||||
export const logger = (namespace: LoggerNamespace) =>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createLogWriter } from '@roarr/browser-log-writer';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import { ROARR, Roarr } from 'roarr';
|
||||
@@ -14,8 +14,8 @@ import {
|
||||
} from './logger';
|
||||
|
||||
const selector = createSelector(
|
||||
systemSelector,
|
||||
(system) => {
|
||||
stateSelector,
|
||||
({ system }) => {
|
||||
const { consoleLogLevel, shouldLogToConsole } = system;
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { BatchConfig } from 'services/api/types';
|
||||
|
||||
export const userInvoked = createAction<InvokeTabName>('app/userInvoked');
|
||||
export const enqueueRequested = createAction<{
|
||||
tabName: InvokeTabName;
|
||||
prepend: boolean;
|
||||
}>('app/enqueueRequested');
|
||||
|
||||
export const batchEnqueued = createAction<BatchConfig>('app/batchEnqueued');
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
||||
import { controlNetDenylist } from 'features/controlNet/store/controlNetDenylist';
|
||||
import { dynamicPromptsPersistDenylist } from 'features/dynamicPrompts/store/dynamicPromptsPersistDenylist';
|
||||
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||
@@ -20,6 +21,7 @@ const serializationDenylist: {
|
||||
system: systemPersistDenylist,
|
||||
ui: uiPersistDenylist,
|
||||
controlNet: controlNetDenylist,
|
||||
dynamicPrompts: dynamicPromptsPersistDenylist,
|
||||
};
|
||||
|
||||
export const serialize: SerializeFunction = (data, key) => {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user