Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-18 01:37:56 -05:00)

Compare commits: lstein/mod...onnx-testi (260 commits)
Commit SHA1s:

e8299d0abb a28ab654ef 8699fd7050 9e65470ada f4e52fafac ee7b36cea5 487455ef2e 632346b2e2
e201ad2f51 0f18898865 700131fab2 634d6bb8a6 2fbf245c3d 39c14eb2ac c89fa4b635 e943913f58
4ada094c5c 893e199677 3c5a0c95b3 71a07ee5a7 2a0a765ec4 ec08151009 186e98da5e dea9a5da7a
bda0000acd 4b678f2416 43fbbfb848 b71bcab691 869f418b03 c364c85915 3773bfbc74 949437b4f0
efcb3a9d08 35d5ef9118 b4eeaaa63c 54bd7c7f04 3240f98f4e 3d4cef0099 3fa7170566 187d7c1cab
7fde1f93ea 9685760fac 3f1d5000c0 ae5cb63f3c 0c18c5d603 7d49c727a0 889b77d3d6 fbbc4b3f69
bc11296a5e 4e13d3408f c19d48abd0 b0fb4950ed 42c440c73f 8bc1fe38b3 65df821233 68f2fb2601
f9459d650e bd4eaa455a b59784e521 769df47863 1cab89fe8c e31b2a6ff4 1c1a72f4c4 5ac6076944
9c3c393b84 112937f1f8 5d635c7221 e6bfc382a5 f970e3792f 3ffca5490e f803d5cf1e ab2343da51
4975b1a704 e1b756658a d17450bbe6 64d676219b 416afd2781 afa84a149c be659364c2 56098f370c
99383c2701 6e40b543cd f287c0174b c955c13b6f ef31837167 3d1ad86e8a b08ad28daa 6c03d9f8f2
9e01a13d63 73eeef34c4 1353bf98b3 e74eac5c91 47617b8f63 9c2a2b313e 32662c5ee8 a61540859e
c16325a244 7221a238b3 af1c1ab51f e7443867f6 025cda3815 84275a3f12 c5b5195f40 d661bf832d
d45ff7e100 9dbffadc6e 11882173e3 990f34aa15 f7de000e79 04c0700762 5b7eef3d43 13da881953
c3a7e35ad8 53db91ef99 ec3c15ead0 0edb31febd a137f7fe7b 179455ef46 6eaa7d212d 7c3eb06a71
6d688ca87d 715e3217d0 72c1a8db08 337399ff7c fbc0694527 47b1a85e70 ccf093b189 ada9b06e48
7ec1be80ad 6ae10798b0 ded5ebc745 65ed43afb9 3f8e978543 0c9c7591c6 0fce35c54c c82ae74610
380aa1d7b5 81ccbc5c6a bcce70fca6 1c680a7147 dcd7e01908 fca6a5dd3c e03e43281b 08854b6d68
0712294c17 0ea8d3c30c 84a13ff8e1 3fba262c94 107ca6bf47 1d3fda80aa e039771d07 cfdaa30d44
3e2a948007 af9e8fefce ba12849685 f398fe4136 41e7b008fb 98e6a56714 cbd5be73d2 38e6e3b36b
c9233eeca2 540f40c293 641b90cc3f aebd595607 ccb43d5a91 ce58c41553 9b55eea673 d9a853857c
036e5d7292 b4e09d4143 bc3aab93f1 2bc3e36bc0 cad3f96831 6534288b75 0a2964d8c0 932112b640
dabd2bf301 91112167b1 5206ddf9b2 92029e69c6 5351171d0e 5b047baeb0 fe78a08e37 d93d42af4a
b767b5d44c c9c2229917 421fcb761b 2e0370d845 72c891bbac 39e66ec934 eda1c94bd6 e95cb3aa71
ab840742b0 be0603b64c 5b5d5ec978 ccbfa5d862 c487166d9c 7b6159f8d6 cd033f4ead b1e16aa3db
e1c0ca1ab2 dcbb3dc49a 4a2f34f77f 558c26d78f 9769b48661 8c8eddcc60 79ca0d0d02 690331b8c0
bd7b59910d 9fb0b0959f d8f88c09ea 524888bf3b b444b8db25 75c5ce46bc 358ced6bab 34cff848c7
4d9a342437 7ce43692c2 23d8a2777e 8e42502dfd d8ebbd258a bf2b5b5cd4 130249a2dd b17406a985
f7d8ae20a6 0327eae509 bb85608890 6c7668aaca 7759b3f75a 4d337f6abc 92c86fd0b8 46dc751139
4cefe37723 82b73c50a0 7df7a95299 85b4b359c2 cfe81b5e00 b0c4451324 d4931522d4 17e2a35228
91016d8b29 9fda21cf40 809ec7163e 7c9a939b47 9634c96020 e0c105f413 f0bf32c476 28373dbb98
4133d77772 61c426f502 bf0577c882 24673fd859 dc669d1447 ce4110b9f4 0f3b7d2b3d 16dc78f6c6
7a66856785 c8dfa49d86 76dd749b1e 67d05d2066
.github/workflows/mkdocs-material.yml (vendored): 2 changes
@@ -43,7 +43,7 @@ jobs:
         --verbose

       - name: deploy to gh-pages
-        if: ${{ github.ref == 'refs/heads/v2.3' }}
+        if: ${{ github.ref == 'refs/heads/main' }}
         run: |
           python -m \
             mkdocs gh-deploy \
@@ -617,8 +617,6 @@ sections describe what's new for InvokeAI.
 - `dream.py` script renamed `invoke.py`. A `dream.py` script wrapper remains for
   backward compatibility.
 - Completely new WebGUI - launch with `python3 scripts/invoke.py --web`
-- Support for [inpainting](deprecated/INPAINTING.md) and
-  [outpainting](features/OUTPAINTING.md)
 - img2img runs on all k\* samplers
 - Support for
   [negative prompts](features/PROMPTS.md#negative-and-unconditioned-prompts)
Binary image file not shown. Before: 983 KiB. After: 1.1 MiB.
@@ -81,3 +81,193 @@ pytest --cov; open ./coverage/html/index.html

<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->

--8<-- "invokeai/frontend/web/README.md"

## Developing InvokeAI in VSCode

VSCode offers some nice tools:

- python debugger
- automatic `venv` activation
- remote dev (e.g. run InvokeAI on a beefy linux desktop while you type in
  comfort on your macbook)

### Setup

You'll need the
[Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python)
and
[Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance)
extensions installed first.

It's also really handy to install the `Jupyter` extensions:

- [Jupyter](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter)
- [Jupyter Cell Tags](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-cell-tags)
- [Jupyter Notebook Renderers](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter-renderers)
- [Jupyter Slide Show](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-slideshow)

#### InvokeAI workspace

Creating a VSCode workspace for working on InvokeAI is highly recommended. It
can hold InvokeAI-specific settings and configs.

To make a workspace:

- Open the InvokeAI repo dir in VSCode
- `File` > `Save Workspace As` > save it _outside_ the repo

#### Default python interpreter (i.e. automatic virtual environment activation)

- Use command palette to run command
  `Preferences: Open Workspace Settings (JSON)`
- Add `python.defaultInterpreterPath` to `settings`, pointing to your `venv`'s
  python

Should look something like this:

```jsonc
{
  // I like to have all InvokeAI-related folders in my workspace
  "folders": [
    {
      // repo root
      "path": "InvokeAI"
    },
    {
      // InvokeAI root dir, where `invokeai.yaml` lives
      "path": "/path/to/invokeai_root"
    }
  ],
  "settings": {
    // Where your InvokeAI `venv`'s python executable lives
    "python.defaultInterpreterPath": "/path/to/invokeai_root/.venv/bin/python"
  }
}
```

Now when you open the VSCode integrated terminal, or do anything that needs to
run python, it will automatically be in your InvokeAI virtual environment.

Bonus: When you create a Jupyter notebook, when you run it, you'll be prompted
for the python interpreter to run in. This will default to your `venv` python,
and so you'll have access to the same python environment as the InvokeAI app.

This is _super_ handy.

#### Debugging configs with `launch.json`

Debugging configs are managed in a `launch.json` file. Like most VSCode configs,
these can be scoped to a workspace or folder.

Follow the [official guide](https://code.visualstudio.com/docs/python/debugging)
to set up your `launch.json` and try it out.

Now we can create the InvokeAI debugging configs:

```jsonc
{
  // Use IntelliSense to learn about possible attributes.
  // Hover to view descriptions of existing attributes.
  // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
  "version": "0.2.0",
  "configurations": [
    {
      // Run the InvokeAI backend & serve the pre-built UI
      "name": "InvokeAI Web",
      "type": "python",
      "request": "launch",
      "program": "scripts/invokeai-web.py",
      "args": [
        // Your InvokeAI root dir (where `invokeai.yaml` lives)
        "--root",
        "/path/to/invokeai_root",
        // Access the app from anywhere on your local network
        "--host",
        "0.0.0.0"
      ],
      "justMyCode": true
    },
    {
      // Run the nodes-based CLI
      "name": "InvokeAI CLI",
      "type": "python",
      "request": "launch",
      "program": "scripts/invokeai-cli.py",
      "justMyCode": true
    },
    {
      // Run tests
      "name": "InvokeAI Test",
      "type": "python",
      "request": "launch",
      "module": "pytest",
      "args": ["--capture=no"],
      "justMyCode": true
    },
    {
      // Run a single test
      "name": "InvokeAI Single Test",
      "type": "python",
      "request": "launch",
      "module": "pytest",
      "args": [
        // Change this to point to the specific test you are working on
        "tests/nodes/test_invoker.py"
      ],
      "justMyCode": true
    },
    {
      // This is the default, useful to just run a single file
      "name": "Python: File",
      "type": "python",
      "request": "launch",
      "program": "${file}",
      "justMyCode": true
    }
  ]
}
```

You'll see these configs in the debugging configs drop down. Running them will
start InvokeAI with attached debugger, in the correct environment, and work just
like the normal app.

Enjoy debugging InvokeAI with ease (not that we have any bugs of course).

#### Remote dev

This is very easy to set up and provides the same very smooth experience as
local development. Environments and debugging, as set up above, just work,
though you'd need to recreate the workspace and debugging configs on the remote.

Consult the
[official guide](https://code.visualstudio.com/docs/remote/remote-overview) to
get it set up.

We suggest using VSCode's included settings sync so that your remote dev host has
all the same app settings and extensions automagically.

##### One remote dev gotcha

I've found the automatic port forwarding to be very flakey. You can disable it
in `Preferences: Open Remote Settings (ssh: hostname)`. Search for
`remote.autoForwardPorts` and untick the box.

To forward ports very reliably, use SSH on the remote dev client (e.g. your
macbook). Here's how to forward both the backend API port (`9090`) and the frontend
live dev server port (`5173`):

```bash
ssh \
    -L 9090:localhost:9090 \
    -L 5173:localhost:5173 \
    user@remote-dev-host
```

The forwarding stops when you close the terminal window, so we suggest doing this
_outside_ the VSCode integrated terminal in case you need to restart VSCode for
an extension update or something.

Now, on your remote dev client, you can open `localhost:9090` and access the UI,
now served from the remote dev host, just the same as if it was running on the
client.
@@ -76,10 +76,10 @@ From top to bottom, these are:
    with outpainting, and modify interior portions of the image with
    inpainting, erase portions of a starting image and have the AI fill in
    the erased region from a text prompt.
-4. Workflow Management (not yet implemented) - this panel will allow you to create
+4. Node Editor - this panel allows you to create
    pipelines of common operations and combine them into workflows.
-5. Training (not yet implemented) - this panel will provide an interface to [textual
-   inversion training](TEXTUAL_INVERSION.md) and fine tuning.
+5. Model Manager - this panel allows you to import and configure new
+   models using URLs, local paths, or HuggingFace diffusers repo_ids.

 The inpainting, outpainting and postprocessing tabs are currently in
 development. However, limited versions of their features can already be accessed
@@ -37,7 +37,7 @@ guide also covers optimizing models to load quickly.
     Teach an old model new tricks. Merge 2-3 models together to create a
     new model that combines characteristics of the originals.

-## * [Textual Inversion](TEXTUAL_INVERSION.md)
+## * [Textual Inversion](TRAINING.md)
     Personalize models by adding your own style or subjects.

 # Other Features
@@ -146,7 +146,6 @@ This method is recommended for those familiar with running Docker containers
 - [Installing](installation/050_INSTALLING_MODELS.md)
 - [Model Merging](features/MODEL_MERGING.md)
 - [Style/Subject Concepts and Embeddings](features/CONCEPTS.md)
-- [Textual Inversion](features/TEXTUAL_INVERSION.md)
 - [Not Safe for Work (NSFW) Checker](features/NSFW.md)
 <!-- seperator -->
 ### Prompt Engineering
@@ -354,8 +354,8 @@ experimental versions later.

 12. **InvokeAI Options**: You can launch InvokeAI with several different command-line arguments that
     customize its behavior. For example, you can change the location of the
-    image output directory, or select your favorite sampler. See the
-    [Command-Line Interface](../features/CLI.md) for a full list of the options.
+    image output directory or balance memory usage vs performance. See
+    [Configuration](../features/CONFIGURATION.md) for a full list of the options.

     - To set defaults that will take effect every time you launch InvokeAI,
       use a text editor (e.g. Notepad) to edit the file
@@ -256,7 +256,7 @@ manager, please follow these steps:

 10. Render away!

-    Browse the [features](../features/CLI.md) section to learn about all the
+    Browse the [features](../features/index.md) section to learn about all the
     things you can do with InvokeAI.
@@ -270,7 +270,7 @@ manager, please follow these steps:

 12. Other scripts

-    The [Textual Inversion](../features/TEXTUAL_INVERSION.md) script can be launched with the command:
+    The [Textual Inversion](../features/TRAINING.md) script can be launched with the command:

     ```bash
     invokeai-ti --gui
@@ -43,24 +43,7 @@ InvokeAI comes with support for a good set of starter models. You'll
 find them listed in the master models file
 `configs/INITIAL_MODELS.yaml` in the InvokeAI root directory. The
 subset that are currently installed are found in
-`configs/models.yaml`. As of v2.3.1, the list of starter models is:
-
-|Model Name | HuggingFace Repo ID | Description | URL |
-|---------- | ---------- | ----------- | --- |
-|stable-diffusion-1.5|runwayml/stable-diffusion-v1-5|Stable Diffusion version 1.5 diffusers model (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-v1-5 |
-|sd-inpainting-1.5|runwayml/stable-diffusion-inpainting|RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-inpainting |
-|stable-diffusion-2.1|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
-|sd-inpainting-2.0|stabilityai/stable-diffusion-2-inpainting|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-inpainting |
-|analog-diffusion-1.0|wavymulder/Analog-Diffusion|An SD-1.5 model trained on diverse analog photographs (2.13 GB)|https://huggingface.co/wavymulder/Analog-Diffusion |
-|deliberate-1.0|XpucT/Deliberate|Versatile model that produces detailed images up to 768px (4.27 GB)|https://huggingface.co/XpucT/Deliberate |
-|d&d-diffusion-1.0|0xJustin/Dungeons-and-Diffusion|Dungeons & Dragons characters (2.13 GB)|https://huggingface.co/0xJustin/Dungeons-and-Diffusion |
-|dreamlike-photoreal-2.0|dreamlike-art/dreamlike-photoreal-2.0|A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)|https://huggingface.co/dreamlike-art/dreamlike-photoreal-2.0 |
-|inkpunk-1.0|Envvi/Inkpunk-Diffusion|Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)|https://huggingface.co/Envvi/Inkpunk-Diffusion |
-|openjourney-4.0|prompthero/openjourney|An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)|https://huggingface.co/prompthero/openjourney |
-|portrait-plus-1.0|wavymulder/portraitplus|An SD-1.5 model trained on close range portraits of people; prompt with "portrait+" (2.13 GB)|https://huggingface.co/wavymulder/portraitplus |
-|seek-art-mega-1.0|coreco/seek.art_MEGA|A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)|https://huggingface.co/coreco/seek.art_MEGA |
-|trinart-2.0|naclbit/trinart_stable_diffusion_v2|An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)|https://huggingface.co/naclbit/trinart_stable_diffusion_v2 |
-|waifu-diffusion-1.4|hakurei/waifu-diffusion|An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB)|https://huggingface.co/hakurei/waifu-diffusion |
+`configs/models.yaml`.

 Note that these files are covered by an "Ethical AI" license which
 forbids certain uses. When you initially download them, you are asked
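For readers who want to see what the removed table now corresponds to locally, a hedged sketch that lists whatever is registered in `configs/models.yaml`. It assumes PyYAML is installed and that the file is a flat name-to-attributes mapping; adjust for your file's actual layout.

```python
# Hedged sketch: print the models registered in configs/models.yaml.
# Assumes a flat mapping of model name -> attribute dict.
import yaml  # pip install pyyaml

with open("configs/models.yaml") as f:
    models = yaml.safe_load(f) or {}

for name, info in models.items():
    if isinstance(info, dict):
        print(f"{name}: {info.get('description', '')}")
```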
@@ -71,8 +54,7 @@ with the model terms by visiting the URLs in the table above.

 ## Community-Contributed Models

-There are too many to list here and more are being contributed every
-day. [HuggingFace](https://huggingface.co/models?library=diffusers)
+[HuggingFace](https://huggingface.co/models?library=diffusers)
 is a great resource for diffusers models, and is also the home of a
 [fast-growing repository](https://huggingface.co/sd-concepts-library)
 of embedding (".bin") models that add subjects and/or styles to your
@@ -86,310 +68,106 @@ only `.safetensors` and `.ckpt` models, but they can be easily loaded
 into InvokeAI and/or converted into optimized `diffusers` models. Be
 aware that CIVITAI hosts many models that generate NSFW content.

-!!! note
-
-    InvokeAI 2.3.x does not support directly importing and
-    running Stable Diffusion version 2 checkpoint models. You may instead
-    convert them into `diffusers` models using the conversion methods
-    described below.
-
 ## Installation

-There are multiple ways to install and manage models:
+There are two ways to install and manage models:

-1. The `invokeai-configure` script which will download and install them for you.
+1. The `invokeai-model-install` script which will download and install
+   them for you. In addition to supporting main models, you can install
+   ControlNet, LoRA and Textual Inversion models.

-2. The command-line tool (CLI) has commands that allow you to import, configure and modify
-   models files.
-
-3. The web interface (WebUI) has a GUI for importing and managing
+2. The web interface (WebUI) has a GUI for importing and managing
    models.

-### Installation via `invokeai-configure`
+3. By placing models (or symbolic links to models) inside one of the
+   InvokeAI root directory's `autoimport` folders.

-From the `invoke` launcher, choose option (6) "re-run the configure
-script to download new models." This will launch the same script that
-prompted you to select models at install time. You can use this to add
-models that you skipped the first time around. It is all right to
-specify a model that was previously downloaded; the script will just
-confirm that the files are complete.
+### Installation via `invokeai-model-install`

-### Installation via the CLI
+From the `invoke` launcher, choose option [5] "Download and install
+models." This will launch the same script that prompted you to select
+models at install time. You can use this to add models that you
+skipped the first time around. It is all right to specify a model that
+was previously downloaded; the script will just confirm that the files
+are complete.

-You can install a new model, including any of the community-supported ones, via
-the command-line client's `!import_model` command.
+The installer has different panels for installing main models from
+HuggingFace, models from Civitai and other arbitrary web sites,
+ControlNet models, LoRA/LyCORIS models, and Textual Inversion
+embeddings. Each section has a text box in which you can enter a new
+model to install. You can refer to a model using its:

-#### Installing individual `.ckpt` and `.safetensors` models
+1. Local path to the .ckpt, .safetensors or diffusers folder on your local machine
+2. A directory on your machine that contains multiple models
+3. A URL that points to a downloadable model
+4. A HuggingFace repo id

-If the model is already downloaded to your local disk, use
-`!import_model /path/to/file.ckpt` to load it. For example:
+Previously-installed models are shown with checkboxes. Uncheck a box
+to unregister the model from InvokeAI. Models that are physically
+installed inside the InvokeAI root directory will be deleted and
+purged (after a confirmation warning). Models that are located outside
+the InvokeAI root directory will be unregistered but not deleted.

-```bash
-invoke> !import_model C:/Users/fred/Downloads/martians.safetensors
-```
+Note: The installer script uses a console-based text interface that requires
+significant amounts of horizontal and vertical space. If the display
+looks messed up, just enlarge the terminal window and/or relaunch the
+script.
+
+If you wish you can script model addition and deletion, as well as
+listing installed models. Start the "developer's console" and give the
+command `invokeai-model-install --help`. This will give you a series
+of command-line parameters that will let you control model
+installation. Examples:
+
+```
+# (list all controlnet models)
+invokeai-model-install --list controlnet
+
+# (install the model at the indicated URL)
+invokeai-model-install --add http://civitai.com/2860
+
+# (delete the named model)
+invokeai-model-install --delete sd-1/main/analog-diffusion
+```

-!!! tip "Forward Slashes"
-    On Windows systems, use forward slashes rather than backslashes
-    in your file paths.
-    If you do use backslashes,
-    you must double them like this:
-    `C:\\Users\\fred\\Downloads\\martians.safetensors`
+### Installation via the Web GUI

-Alternatively you can directly import the file using its URL:
+To install a new model using the Web GUI, do the following:

-```bash
-invoke> !import_model https://example.org/sd_models/martians.safetensors
-```
+1. Open the InvokeAI Model Manager (cube at the bottom of the
+   left-hand panel) and navigate to *Import Models*.

-For this to work, the URL must not be password-protected. Otherwise
-you will receive a 404 error.
+2. In the field labeled *Location* type in the path to the model you
+   wish to install. You may use a URL, HuggingFace repo id, or a path on
+   your local disk.

-When you import a legacy model, the CLI will first ask you what type
-of model this is. You can indicate whether it is a model based on
-Stable Diffusion 1.x (1.4 or 1.5), one based on Stable Diffusion 2.x,
-or a 1.x inpainting model. Be careful to indicate the correct model
-type, or it will not load correctly. You can correct the model type
-after the fact using the `!edit_model` command.
+3. Alternatively, the *Scan for Models* button allows you to paste in
+   the path to a folder somewhere on your machine. It will be scanned for
+   importable models and prompt you to add the ones of your choice.

-The system will then ask you a few other questions about the model,
-including what size image it was trained on (usually 512x512), what
-name and description you wish to use for it, and whether you would
-like to install a custom VAE (variable autoencoder) file for the
-model. For recent models, the answer to the VAE question is usually
-"no," but it won't hurt to answer "yes".
+4. Press *Add Model* and wait for confirmation that the model
+   was added.

-After importing, the model will load. If this is successful, you will
-be asked if you want to keep the model loaded in memory to start
-generating immediately. You'll also be asked if you wish to make this
-the default model on startup. You can change this later using
-`!edit_model`.
+To delete a model, select *Model Manager* to list all the currently
+installed models. Press the trash can icons to delete any models you
+wish to get rid of. Models whose weights are located inside the
+InvokeAI `models` directory will be purged from disk, while those
+located outside will be unregistered from InvokeAI, but not deleted.

-#### Importing a batch of `.ckpt` and `.safetensors` models from a directory
+You can see where model weights are located by clicking on the model name.
+This will bring up an editable info panel showing the model's characteristics,
+including the `Model Location` of its files.

-You may also point `!import_model` to a directory containing a set of
-`.ckpt` or `.safetensors` files. They will be imported _en masse_.
+### Installation via the `autoimport` function

-!!! example
+In the InvokeAI root directory you will find a series of folders under
+`autoimport`, one each for main models, controlnets, embeddings and
+Loras. Any models that you add to these directories will be scanned
+at startup time and registered automatically.

-    ```console
-    invoke> !import_model C:/Users/fred/Downloads/civitai_models/
-    ```
+You may create symbolic links from these folders to models located
+elsewhere on disk and they will be autoimported. You can also create
+subfolders and organize them as you wish.

-You will be given the option to import all models found in the
-directory, or select which ones to import. If there are subfolders
-within the directory, they will be searched for models to import.
-
-#### Installing `diffusers` models
-
-You can install a `diffusers` model from the HuggingFace site using
-`!import_model` and the HuggingFace repo_id for the model:
-
-```bash
-invoke> !import_model andite/anything-v4.0
-```
-
-Alternatively, you can download the model to disk and import it from
-there. The model may be distributed as a ZIP file, or as a Git
-repository:
-
-```bash
-invoke> !import_model C:/Users/fred/Downloads/andite--anything-v4.0
-```
-
-!!! tip "The CLI supports file path autocompletion"
-    Type a bit of the path name and hit ++tab++ in order to get a choice of
-    possible completions.
-
-!!! tip "On Windows, you can drag model files onto the command-line"
-    Once you have typed in `!import_model `, you can drag the
-    model file or directory onto the command-line to insert the model path. This way, you don't need to
-    type it or copy/paste. However, you will need to reverse or
-    double backslashes as noted above.
-
-Before installing, the CLI will ask you for a short name and
-description for the model, whether to make this the default model that
-is loaded at InvokeAI startup time, and whether to replace its
-VAE. Generally the answer to the latter question is "no".
-
-### Converting legacy models into `diffusers`
-
-The CLI `!convert_model` will convert a `.safetensors` or `.ckpt`
-models file into `diffusers` and install it. This will enable the model
-to load and run faster without loss of image quality.
-
-The usage is identical to `!import_model`. You may point the command
-to either a downloaded model file on disk, or to a (non-password
-protected) URL:
-
-```bash
-invoke> !convert_model C:/Users/fred/Downloads/martians.safetensors
-```
-
-After a successful conversion, the CLI will offer you the option of
-deleting the original `.ckpt` or `.safetensors` file.
-
-### Optimizing a previously-installed model
-
-Lastly, if you have previously installed a `.ckpt` or `.safetensors`
-file and wish to convert it into a `diffusers` model, you can do this
-without re-downloading and converting the original file using the
-`!optimize_model` command. Simply pass the short name of an existing
-installed model:
-
-```bash
-invoke> !optimize_model martians-v1.0
-```
-
-The model will be converted into `diffusers` format and replace the
-previously installed version. You will again be offered the
-opportunity to delete the original `.ckpt` or `.safetensors` file.
-
-### Related CLI Commands
-
-There are a whole series of additional model management commands in
-the CLI that you can read about in [Command-Line
-Interface](../features/CLI.md). These include:
-
-* `!models` - List all installed models
-* `!switch <model name>` - Switch to the indicated model
-* `!edit_model <model name>` - Edit the indicated model to change its name, description or other properties
-* `!del_model <model name>` - Delete the indicated model
-
-### Manually editing `configs/models.yaml`
-
-If you are comfortable with a text editor then you may simply edit `models.yaml`
-directly.
-
-You will need to download the desired `.ckpt/.safetensors` file and
-place it somewhere on your machine's filesystem. Alternatively, for a
-`diffusers` model, record the repo_id or download the whole model
-directory. Then using a **text** editor (e.g. the Windows Notepad
-application), open the file `configs/models.yaml`, and add a new
-stanza that follows this model:
-
-#### A legacy model
-
-A legacy `.ckpt` or `.safetensors` entry will look like this:
-
-```yaml
-arabian-nights-1.0:
-  description: A great fine-tune in Arabian Nights style
-  weights: ./path/to/arabian-nights-1.0.ckpt
-  config: ./configs/stable-diffusion/v1-inference.yaml
-  format: ckpt
-  width: 512
-  height: 512
-  default: false
-```
-
-Note that `format` is `ckpt` for both `.ckpt` and `.safetensors` files.
-
-#### A diffusers model
-
-A stanza for a `diffusers` model will look like this for a HuggingFace
-model with a repository ID:
-
-```yaml
-arabian-nights-1.1:
-  description: An even better fine-tune of the Arabian Nights
-  repo_id: captahab/arabian-nights-1.1
-  format: diffusers
-  default: true
-```
-
-And for a downloaded directory:
-
-```yaml
-arabian-nights-1.1:
-  description: An even better fine-tune of the Arabian Nights
-  path: /path/to/captahab-arabian-nights-1.1
-  format: diffusers
-  default: true
-```
-
-There is additional syntax for indicating an external VAE to use with
-this model. See `INITIAL_MODELS.yaml` and `models.yaml` for examples.
-
-After you save the modified `models.yaml` file, relaunch
-`invokeai`. The new model will now be available for your use.
-
-### Installation via the WebUI
-
-To access the WebUI Model Manager, click on the button that looks like
-a cube in the upper right side of the browser screen. This will bring
-up a dialogue that lists the models you have already installed, and
-allows you to load, delete or edit them:
-
-<figure markdown>
-
-</figure>
-
-To add a new model, click on **+ Add New** and select to either a
-checkpoint/safetensors model, or a diffusers model:
-
-<figure markdown>
-
-</figure>
-
-In this example, we chose **Add Diffusers**. As shown in the figure
-below, a new dialogue prompts you to enter the name to use for the
-model, its description, and either the location of the `diffusers`
-model on disk, or its Repo ID on the HuggingFace web site. If you
-choose to enter a path to disk, the system will autocomplete for you
-as you type:
-
-<figure markdown>
-
-</figure>
-
-Press **Add Model** at the bottom of the dialogue (scrolled out of
-sight in the figure), and the model will be downloaded, imported, and
-registered in `models.yaml`.
-
-The **Add Checkpoint/Safetensor Model** option is similar, except that
-in this case you can choose to scan an entire folder for
-checkpoint/safetensors files to import. Simply type in the path of the
-directory and press the "Search" icon. This will display the
-`.ckpt` and `.safetensors` found inside the directory and its
-subfolders, and allow you to choose which ones to import:
-
-<figure markdown>
-
-</figure>
-
-## Model Management Startup Options
-
-The `invoke` launcher and the `invokeai` script accept a series of
-command-line arguments that modify InvokeAI's behavior when loading
-models. These can be provided on the command line, or added to the
-InvokeAI root directory's `invokeai.init` initialization file.
-
-The arguments are:
-
-* `--model <model name>` -- Start up with the indicated model loaded
-* `--ckpt_convert` -- When a checkpoint/safetensors model is loaded, convert it into a `diffusers` model in memory. This does not permanently save the converted model to disk.
-* `--autoconvert <path/to/directory>` -- Scan the indicated directory path for new checkpoint/safetensors files, convert them into `diffusers` models, and import them into InvokeAI.
-
-Here is an example of providing an argument on the command line using
-the `invoke.sh` launch script:
-
-```bash
-invoke.sh --autoconvert /home/fred/stable-diffusion-checkpoints
-```
-
-And here is what the same argument looks like in `invokeai.init`:
-
-```bash
---outdir="/home/fred/invokeai/outputs
---no-nsfw_checker
---autoconvert /home/fred/stable-diffusion-checkpoints
-```
+The location of the autoimport directories is controlled by settings
+in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
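As one way to use the `autoimport` folders the new text describes, here is a small sketch that symlinks an existing checkpoint into the main-models autoimport directory. All paths are placeholders.

```python
# Sketch: symlink a model into the autoimport folder so it is
# scanned and registered on the next InvokeAI startup.
from pathlib import Path

model = Path("/big-disk/models/martians.safetensors")          # placeholder
autoimport_main = Path("/path/to/invokeai_root/autoimport/main")  # placeholder
link = autoimport_main / model.name
if not link.exists():
    link.symlink_to(model)
```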
@@ -24,7 +24,8 @@ read -e -p "Tag this repo with '${VERSION}' and '${LATEST_TAG}'? [n]: " input
 RESPONSE=${input:='n'}
 if [ "$RESPONSE" == 'y' ]; then

-    if ! git tag $VERSION ; then
+    git push origin :refs/tags/$VERSION
+    if ! git tag -fa $VERSION ; then
        echo "Existing/invalid tag"
        exit -1
     fi
@@ -38,7 +38,7 @@ echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist
 echo.
 echo See %INSTRUCTIONS% for more details.
 echo.
-echo "For the best user experience we suggest enlarging or maximizing this window now."
+echo FOR THE BEST USER EXPERIENCE WE SUGGEST MAXIMIZING THIS WINDOW NOW.
 pause

 @rem ---------------------------- check Python version ---------------
@@ -19,7 +19,7 @@ echo 8. Open the developer console
 echo 9. Update InvokeAI
 echo 10. Command-line help
 echo Q - Quit
-set /P choice="Please enter 1-10, Q: [2] "
+set /P choice="Please enter 1-10, Q: [1] "
 if not defined choice set choice=1
 IF /I "%choice%" == "1" (
     echo Starting the InvokeAI browser-based UI..
@@ -14,6 +14,7 @@ from invokeai.backend.model_management.models import (
     OPENAPI_MODEL_CONFIGS,
     SchedulerPredictionType,
     ModelNotFoundException,
+    InvalidModelException,
 )
 from invokeai.backend.model_management import MergeInterpolationMethod
@@ -36,11 +37,16 @@ class ModelsList(BaseModel):
     responses={200: {"model": ModelsList }},
 )
 async def list_models(
-    base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
+    base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
     model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
 ) -> ModelsList:
     """Gets a list of models"""
-    models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
+    if base_models and len(base_models) > 0:
+        models_raw = list()
+        for base_model in base_models:
+            models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
+    else:
+        models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
     models = parse_obj_as(ModelsList, { "models": models_raw })
     return models
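A hedged client-side sketch of what the new `base_models` query enables. The host, port, and route prefix are assumptions, not taken from the diff; FastAPI collects repeated query keys into the `List[BaseModelType]` parameter, so one request can cover several base models.

```python
import requests

resp = requests.get(
    "http://localhost:9090/api/v1/models/",  # assumed route
    # repeated "base_models" keys populate the list parameter
    params=[("base_models", "sd-1"), ("base_models", "sd-2"), ("model_type", "main")],
)
resp.raise_for_status()
for model in resp.json()["models"]:
    print(model.get("model_name"))
```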
@@ -63,20 +69,35 @@ async def update_model(
 ) -> UpdateModelResponse:
     """ Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
     logger = ApiDependencies.invoker.services.logger

     try:
+        previous_info = ApiDependencies.invoker.services.model_manager.list_model(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+        )
+
         # rename operation requested
         if info.model_name != model_name or info.base_model != base_model:
-            result = ApiDependencies.invoker.services.model_manager.rename_model(
+            ApiDependencies.invoker.services.model_manager.rename_model(
                 base_model = base_model,
                 model_type = model_type,
                 model_name = model_name,
                 new_name = info.model_name,
                 new_base = info.base_model,
             )
-            logger.debug(f'renaming result = {result}')
+            logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
+            # update information to support an update of attributes
+            model_name = info.model_name
+            base_model = info.base_model
+            new_info = ApiDependencies.invoker.services.model_manager.list_model(
+                model_name=model_name,
+                base_model=base_model,
+                model_type=model_type,
+            )
+            if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
+                info.path = new_info.get('path')

         ApiDependencies.invoker.services.model_manager.update_model(
             model_name=model_name,
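To illustrate the rename-through-update flow from the client side, a hypothetical call; the URL layout and payload fields here are assumptions, not confirmed by the diff.

```python
# Hypothetical: rename a model by sending an updated config whose
# model_name differs from the one in the URL. Per the server logic
# above, the model manager renames the model and, if it moved the
# weights, the new path wins over the stale one in the request body.
import requests

resp = requests.patch(
    "http://localhost:9090/api/v1/models/sd-1/main/old-name",  # assumed route
    json={"model_name": "new-name", "base_model": "sd-1", "description": "renamed"},
)
print(resp.status_code, resp.json())
```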
@@ -108,6 +129,7 @@ async def update_model(
     responses= {
         201: {"description" : "The model imported successfully"},
         404: {"description" : "The model could not be found"},
+        415: {"description" : "Unrecognized file/folder format"},
         424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
         409: {"description" : "There is already a model corresponding to this path or repo_id"},
     },
@@ -134,7 +156,7 @@ async def import_model(

         if not info:
             logger.error("Import failed")
-            raise HTTPException(status_code=424)
+            raise HTTPException(status_code=415)

         logger.info(f'Successfully imported {location}, got {info}')
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
@@ -147,6 +169,9 @@ async def import_model(
     except ModelNotFoundException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
+    except InvalidModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=415)
     except ValueError as e:
         logger.error(str(e))
         raise HTTPException(status_code=409, detail=str(e))
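A sketch of how a client might branch on the richer status codes introduced above; the route and payload are assumptions.

```python
import requests

resp = requests.post(
    "http://localhost:9090/api/v1/models/import",  # assumed route
    json={"location": "https://example.org/sd_models/martians.safetensors"},
)
# Status codes taken from the responses mapping in the diff.
messages = {
    201: "imported successfully",
    404: "model could not be found",
    409: "a model with this path or repo_id already exists",
    415: "unrecognized file/folder format",
    424: "import reported success but the model never registered",
}
print(messages.get(resp.status_code, f"unexpected status {resp.status_code}"))
```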
@@ -4,17 +4,12 @@ from typing import Literal

 import numpy as np
 from pydantic import Field, validator
+import re

-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
-from .model import ClipField
-
+from invokeai.app.models.image import ImageField
 from invokeai.app.util.misc import SEED_MAX, get_random_seed

-from .baseinvocation import (
-    BaseInvocation,
-    InvocationConfig,
-    InvocationContext,
-    BaseInvocationOutput,
-    UIConfig,
-)
+from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
+                             InvocationConfig, InvocationContext, UIConfig)


 class IntCollectionOutput(BaseInvocationOutput):
@@ -32,7 +27,8 @@ class FloatCollectionOutput(BaseInvocationOutput):
     type: Literal["float_collection"] = "float_collection"

     # Outputs
-    collection: list[float] = Field(default=[], description="The float collection")
+    collection: list[float] = Field(
+        default=[], description="The float collection")


 class ImageCollectionOutput(BaseInvocationOutput):
@@ -41,7 +37,8 @@ class ImageCollectionOutput(BaseInvocationOutput):
     type: Literal["image_collection"] = "image_collection"

     # Outputs
-    collection: list[ImageField] = Field(default=[], description="The output images")
+    collection: list[ImageField] = Field(
+        default=[], description="The output images")

     class Config:
         schema_extra = {"required": ["type", "collection"]}
@@ -57,6 +54,14 @@ class RangeInvocation(BaseInvocation):
     stop: int = Field(default=10, description="The stop of the range")
     step: int = Field(default=1, description="The step of the range")

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Range",
+                "tags": ["range", "integer", "collection"]
+            },
+        }
+
     @validator("stop")
     def stop_gt_start(cls, v, values):
         if "start" in values and v <= values["start"]:
@@ -79,10 +84,20 @@ class RangeOfSizeInvocation(BaseInvocation):
     size: int = Field(default=1, description="The number of values")
     step: int = Field(default=1, description="The step of the range")

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Sized Range",
+                "tags": ["range", "integer", "size", "collection"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> IntCollectionOutput:
         return IntCollectionOutput(
-            collection=list(range(self.start, self.start + self.size, self.step))
-        )
+            collection=list(
+                range(
+                    self.start, self.start + self.size,
+                    self.step)))


 class RandomRangeInvocation(BaseInvocation):
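A standalone sketch of the collection the reformatted `invoke` builds. Note that with `step > 1` the list holds `ceil(size / step)` values, since `size` bounds the range's span rather than its length.

```python
start, size, step = 0, 10, 3
collection = list(range(start, start + size, step))
print(collection)  # [0, 3, 6, 9] -- four values, not ten
```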
@@ -103,11 +118,21 @@ class RandomRangeInvocation(BaseInvocation):
         default_factory=get_random_seed,
     )

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Random Range",
+                "tags": ["range", "integer", "random", "collection"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> IntCollectionOutput:
         rng = np.random.default_rng(self.seed)
         return IntCollectionOutput(
-            collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
-        )
+            collection=list(
+                rng.integers(
+                    low=self.low, high=self.high,
+                    size=self.size)))


 class ImageCollectionInvocation(BaseInvocation):
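A standalone sketch of the seeded sampling `RandomRangeInvocation` relies on: reusing the same seed with `numpy.random.default_rng` reproduces the same collection.

```python
import numpy as np

for _ in range(2):
    rng = np.random.default_rng(seed=123)
    print(list(rng.integers(low=0, high=10, size=5)))  # identical both times
```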
@@ -121,6 +146,7 @@ class ImageCollectionInvocation(BaseInvocation):
         default=[], description="The image collection to load"
     )
     # fmt: on
+
     def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
         return ImageCollectionOutput(collection=self.images)
@@ -128,6 +154,7 @@ class ImageCollectionInvocation(BaseInvocation):
         schema_extra = {
             "ui": {
                 "type_hints": {
+                    "title": "Image Collection",
                     "images": "image_collection",
                 }
             },
@@ -1,8 +1,16 @@
-from typing import Literal, Optional, Union, List
+from typing import Literal, Optional, Union, List, Annotated
 from pydantic import BaseModel, Field
+import re

-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
-from .model import ClipField
-
 from ...backend.util.devices import torch_dtype
 from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
+from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher

 import torch
-from compel import Compel
+from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import (Blend, Conjunction,
                                   CrossAttentionControlSubstitute,
                                   FlattenedPrompt, Fragment)
@@ -14,6 +22,7 @@ from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
 from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                              InvocationConfig, InvocationContext)
 from .model import ClipField
+from dataclasses import dataclass


 class ConditioningField(BaseModel):
@@ -23,6 +32,34 @@ class ConditioningField(BaseModel):
     class Config:
         schema_extra = {"required": ["conditioning_name"]}

+@dataclass
+class BasicConditioningInfo:
+    #type: Literal["basic_conditioning"] = "basic_conditioning"
+    embeds: torch.Tensor
+    extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
+    # weight: float
+    # mode: ConditioningAlgo
+
+@dataclass
+class SDXLConditioningInfo(BasicConditioningInfo):
+    #type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
+    pooled_embeds: torch.Tensor
+    add_time_ids: torch.Tensor
+
+ConditioningInfoType = Annotated[
+    Union[BasicConditioningInfo, SDXLConditioningInfo],
+    Field(discriminator="type")
+]
+
+@dataclass
+class ConditioningFieldData:
+    conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
+    #unconditioned: Optional[torch.Tensor]
+
+#class ConditioningAlgo(str, Enum):
+#    Compose = "compose"
+#    ComposeEx = "compose_ex"
+#    PerpNeg = "perp_neg"

 class CompelOutput(BaseInvocationOutput):
     """Compel parser output"""
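A self-contained sketch of how the new dataclasses nest. Tensor shapes are illustrative only, and `extra_conditioning` is typed loosely here to keep the example standalone.

```python
from dataclasses import dataclass
from typing import List, Optional

import torch

@dataclass
class BasicConditioningInfo:
    embeds: torch.Tensor
    extra_conditioning: Optional[object]

@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

@dataclass
class ConditioningFieldData:
    conditionings: List[BasicConditioningInfo]

data = ConditioningFieldData(
    conditionings=[
        SDXLConditioningInfo(
            embeds=torch.zeros(1, 77, 2048),     # illustrative shape
            extra_conditioning=None,
            pooled_embeds=torch.zeros(1, 1280),  # illustrative shape
            # (orig_h, orig_w) + (crop_top, crop_left) + (target_h, target_w)
            add_time_ids=torch.tensor([[1024, 1024, 0, 0, 1024, 1024]]),
        )
    ]
)
print(data.conditionings[0].embeds.shape)
```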
@@ -57,10 +94,10 @@ class CompelInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
         tokenizer_info = context.services.model_manager.get_model(
-            **self.clip.tokenizer.dict(),
+            **self.clip.tokenizer.dict(), context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **self.clip.text_encoder.dict(),
+            **self.clip.text_encoder.dict(), context=context,
         )

         def _lora_loader():
@@ -82,6 +119,7 @@ class CompelInvocation(BaseInvocation):
                         model_name=name,
                         base_model=self.clip.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
+                        context=context,
                     ).context.model
                 )
             except ModelNotFoundException:
@@ -100,7 +138,7 @@ class CompelInvocation(BaseInvocation):
             text_encoder=text_encoder,
             textual_inversion_manager=ti_manager,
             dtype_for_device_getter=torch_dtype,
-            truncate_long_prompts=False,
+            truncate_long_prompts=True,
         )

         conjunction = Compel.parse_prompt_string(self.prompt)
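A minimal standalone sketch of the flag being flipped here, using a stock CLIP checkpoint (the model ID is an assumption, not taken from the diff).

```python
from compel import Compel
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# truncate_long_prompts=True clips prompts at the 77-token context window
# instead of chunking the overflow into extra encoder passes.
compel = Compel(tokenizer=tokenizer, text_encoder=text_encoder,
                truncate_long_prompts=True)
conditioning = compel.build_conditioning_tensor("a castle on a hill at dusk")
print(conditioning.shape)
```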
@@ -118,10 +156,19 @@ class CompelInvocation(BaseInvocation):
             cross_attention_control_args=options.get(
                 "cross_attention_control", None),)

-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+        c = c.detach().to("cpu")

-        # TODO: hacky but works ;D maybe rename latents somehow?
-        context.services.latents.save(conditioning_name, (c, ec))
+        conditioning_data = ConditioningFieldData(
+            conditionings=[
+                BasicConditioningInfo(
+                    embeds=c,
+                    extra_conditioning=ec,
+                )
+            ]
+        )
+
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+        context.services.latents.save(conditioning_name, conditioning_data)

         return CompelOutput(
             conditioning=ConditioningField(
@@ -129,6 +176,397 @@ class CompelInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
# print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
||||
text_encoder_info as text_encoder:
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(text_encoder.device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
if get_pooled:
|
||||
c_pooled = prompt_embeds[0]
|
||||
else:
|
||||
c_pooled = None
|
||||
c = prompt_embeds.hidden_states[-2]
|
||||
|
||||
del tokenizer
|
||||
del text_encoder
|
||||
del tokenizer_info
|
||||
del text_encoder_info
|
||||
|
||||
c = c.detach().to("cpu")
|
||||
if c_pooled is not None:
|
||||
c_pooled = c_pooled.detach().to("cpu")
|
||||
|
||||
return c, c_pooled, None
|
||||
|
||||
    def run_clip_compel(self, context, clip_field, prompt, get_pooled):
        tokenizer_info = context.services.model_manager.get_model(
            **clip_field.tokenizer.dict(),
        )
        text_encoder_info = context.services.model_manager.get_model(
            **clip_field.text_encoder.dict(),
        )

        def _lora_loader():
            for lora in clip_field.loras:
                lora_info = context.services.model_manager.get_model(
                    **lora.dict(exclude={"weight"}))
                yield (lora_info.context.model, lora.weight)
                del lora_info
            return

        #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

        ti_list = []
        for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
            name = trigger[1:-1]
            try:
                ti_list.append(
                    context.services.model_manager.get_model(
                        model_name=name,
                        base_model=clip_field.text_encoder.base_model,
                        model_type=ModelType.TextualInversion,
                    ).context.model
                )
            except ModelNotFoundException:
                # print(e)
                #import traceback
                #print(traceback.format_exc())
                print(f"Warn: trigger: \"{trigger}\" not found")

        with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
                ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
                ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
                text_encoder_info as text_encoder:

            compel = Compel(
                tokenizer=tokenizer,
                text_encoder=text_encoder,
                textual_inversion_manager=ti_manager,
                dtype_for_device_getter=torch_dtype,
                truncate_long_prompts=True,  # TODO:
                returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
                requires_pooled=True,
            )

            conjunction = Compel.parse_prompt_string(prompt)

            if context.services.configuration.log_tokenization:
                # TODO: better logging for and syntax
                for prompt_obj in conjunction.prompts:
                    log_tokenization_for_prompt_object(prompt_obj, tokenizer)

            # TODO: ask for optimizations? to not run text_encoder twice
            c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
            if get_pooled:
                c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
            else:
                c_pooled = None

            ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
                tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
                cross_attention_control_args=options.get("cross_attention_control", None),
            )

        del tokenizer
        del text_encoder
        del tokenizer_info
        del text_encoder_info

        c = c.detach().to("cpu")
        if c_pooled is not None:
            c_pooled = c_pooled.detach().to("cpu")

        return c, c_pooled, ec

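The textual-inversion lookup in both methods hinges on the re.findall(r"<[a-zA-Z0-9., _-]+>", prompt) pattern: anything wrapped in angle brackets is treated as a candidate embedding name, and names the model manager cannot resolve fall through to the warning instead of failing the whole prompt. A small, self-contained illustration of what the pattern matches:

import re

TRIGGER_RE = r"<[a-zA-Z0-9., _-]+>"

prompt = "photo of <easynegative> person, <bad hands>, <4k>"
triggers = re.findall(TRIGGER_RE, prompt)
print(triggers)  # ['<easynegative>', '<bad hands>', '<4k>']

# The name handed to the model manager is the trigger minus its brackets:
for trigger in triggers:
    print(trigger[1:-1])
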
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
    """Parse prompt using compel package to conditioning."""

    type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"

    prompt: str = Field(default="", description="Prompt")
    style: str = Field(default="", description="Style prompt")
    original_width: int = Field(1024, description="")
    original_height: int = Field(1024, description="")
    crop_top: int = Field(0, description="")
    crop_left: int = Field(0, description="")
    target_width: int = Field(1024, description="")
    target_height: int = Field(1024, description="")
    clip: ClipField = Field(None, description="Clip to use")
    clip2: ClipField = Field(None, description="Clip2 to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Prompt (Compel)",
                "tags": ["prompt", "compel"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> CompelOutput:
        c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
        if self.style.strip() == "":
            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
        else:
            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)

        original_size = (self.original_height, self.original_width)
        crop_coords = (self.crop_top, self.crop_left)
        target_size = (self.target_height, self.target_width)

        add_time_ids = torch.tensor([
            original_size + crop_coords + target_size
        ])

        conditioning_data = ConditioningFieldData(
            conditionings=[
                SDXLConditioningInfo(
                    embeds=torch.cat([c1, c2], dim=-1),
                    pooled_embeds=c2_pooled,
                    add_time_ids=add_time_ids,
                    extra_conditioning=ec1,
                )
            ]
        )

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
        context.services.latents.save(conditioning_name, conditioning_data)

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )

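The three tuples built in invoke() are SDXL's micro-conditioning inputs, packed into a single 1x6 tensor in the order (original_height, original_width, crop_top, crop_left, target_height, target_width). A worked example with made-up values:

import torch

original_size = (1024, 1024)  # (height, width) the image is declared to come from
crop_coords = (0, 0)          # (top, left) of the crop
target_size = (1024, 1024)    # (height, width) actually being generated

add_time_ids = torch.tensor([original_size + crop_coords + target_size])
print(add_time_ids)        # tensor([[1024, 1024, 0, 0, 1024, 1024]])
print(add_time_ids.shape)  # torch.Size([1, 6])
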
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
    """Parse prompt using compel package to conditioning."""

    type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"

    style: str = Field(default="", description="Style prompt")  # TODO: ?
    original_width: int = Field(1024, description="")
    original_height: int = Field(1024, description="")
    crop_top: int = Field(0, description="")
    crop_left: int = Field(0, description="")
    aesthetic_score: float = Field(6.0, description="")
    clip2: ClipField = Field(None, description="Clip to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Refiner Prompt (Compel)",
                "tags": ["prompt", "compel"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> CompelOutput:
        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)

        original_size = (self.original_height, self.original_width)
        crop_coords = (self.crop_top, self.crop_left)

        add_time_ids = torch.tensor([
            original_size + crop_coords + (self.aesthetic_score,)
        ])

        conditioning_data = ConditioningFieldData(
            conditionings=[
                SDXLConditioningInfo(
                    embeds=c2,
                    pooled_embeds=c2_pooled,
                    add_time_ids=add_time_ids,
                    extra_conditioning=ec2,  # or None
                )
            ]
        )

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
        context.services.latents.save(conditioning_name, conditioning_data)

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )

class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
    """Pass unmodified prompt to conditioning without compel processing."""

    type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"

    prompt: str = Field(default="", description="Prompt")
    style: str = Field(default="", description="Style prompt")
    original_width: int = Field(1024, description="")
    original_height: int = Field(1024, description="")
    crop_top: int = Field(0, description="")
    crop_left: int = Field(0, description="")
    target_width: int = Field(1024, description="")
    target_height: int = Field(1024, description="")
    clip: ClipField = Field(None, description="Clip to use")
    clip2: ClipField = Field(None, description="Clip2 to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Prompt (Raw)",
                "tags": ["prompt", "compel"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> CompelOutput:
        c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
        if self.style.strip() == "":
            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
        else:
            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)

        original_size = (self.original_height, self.original_width)
        crop_coords = (self.crop_top, self.crop_left)
        target_size = (self.target_height, self.target_width)

        add_time_ids = torch.tensor([
            original_size + crop_coords + target_size
        ])

        conditioning_data = ConditioningFieldData(
            conditionings=[
                SDXLConditioningInfo(
                    embeds=torch.cat([c1, c2], dim=-1),
                    pooled_embeds=c2_pooled,
                    add_time_ids=add_time_ids,
                    extra_conditioning=ec1,
                )
            ]
        )

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
        context.services.latents.save(conditioning_name, conditioning_data)

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )

class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
    """Pass unmodified style prompt to conditioning without compel processing."""

    type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"

    style: str = Field(default="", description="Style prompt")  # TODO: ?
    original_width: int = Field(1024, description="")
    original_height: int = Field(1024, description="")
    crop_top: int = Field(0, description="")
    crop_left: int = Field(0, description="")
    aesthetic_score: float = Field(6.0, description="")
    clip2: ClipField = Field(None, description="Clip to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Refiner Prompt (Raw)",
                "tags": ["prompt", "compel"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> CompelOutput:
        c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)

        original_size = (self.original_height, self.original_width)
        crop_coords = (self.crop_top, self.crop_left)

        add_time_ids = torch.tensor([
            original_size + crop_coords + (self.aesthetic_score,)
        ])

        conditioning_data = ConditioningFieldData(
            conditionings=[
                SDXLConditioningInfo(
                    embeds=c2,
                    pooled_embeds=c2_pooled,
                    add_time_ids=add_time_ids,
                    extra_conditioning=ec2,  # or None
                )
            ]
        )

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
        context.services.latents.save(conditioning_name, conditioning_data)

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )


class ClipSkipInvocationOutput(BaseInvocationOutput):
    """Clip skip node output"""
    type: Literal["clip_skip_output"] = "clip_skip_output"
@@ -141,6 +579,14 @@ class ClipSkipInvocation(BaseInvocation):
    clip: ClipField = Field(None, description="Clip to use")
    skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "CLIP Skip",
                "tags": ["clip", "skip"]
            },
        }

    def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
        self.clip.skipped_layers += self.skipped_layers
        return ClipSkipInvocationOutput(

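Note that invoke() adds its own skipped_layers onto the incoming clip field, so chaining two CLIP Skip nodes accumulates skips rather than overwriting them. A toy sketch of that behavior (the dataclass is a stand-in for the real ClipField, not the project's model):

from dataclasses import dataclass

@dataclass
class FakeClipField:  # stand-in for the real ClipField
    skipped_layers: int = 0

def clip_skip(clip: FakeClipField, skipped_layers: int) -> FakeClipField:
    # mirrors: self.clip.skipped_layers += self.skipped_layers
    clip.skipped_layers += skipped_layers
    return clip

clip = clip_skip(clip_skip(FakeClipField(), 1), 2)
print(clip.skipped_layers)  # 3 -- two chained nodes accumulate
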
@@ -1,43 +1,25 @@
 # Invocations for ControlNet image preprocessors
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
-from builtins import float, bool
+from builtins import bool, float
+from typing import Dict, List, Literal, Optional, Union

 import cv2
 import numpy as np
-from typing import Literal, Optional, Union, List, Dict
+from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
+                            LeresDetector, LineartAnimeDetector,
+                            LineartDetector, MediapipeFaceDetector,
+                            MidasDetector, MLSDdetector, NormalBaeDetector,
+                            OpenposeDetector, PidiNetDetector, SamDetector,
+                            ZoeDetector)
+from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
 from pydantic import BaseModel, Field, validator

 from ...backend.model_management import BaseModelType, ModelType
-from ..models.image import ImageField, ImageCategory, ResourceOrigin
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    InvocationContext,
-    InvocationConfig,
-)
-
-from controlnet_aux import (
-    CannyDetector,
-    HEDdetector,
-    LineartDetector,
-    LineartAnimeDetector,
-    MidasDetector,
-    MLSDdetector,
-    NormalBaeDetector,
-    OpenposeDetector,
-    PidiNetDetector,
-    ContentShuffleDetector,
-    ZoeDetector,
-    MediapipeFaceDetector,
-    SamDetector,
-    LeresDetector,
-)
-
-from controlnet_aux.util import HWC3, ade_palette
-
-
+from ..models.image import ImageCategory, ImageField, ResourceOrigin
+from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
+                             InvocationConfig, InvocationContext)
 from .image import ImageOutput, PILInvocationConfig

 CONTROLNET_DEFAULT_MODELS = [
@@ -75,33 +57,34 @@ CONTROLNET_DEFAULT_MODELS = [
    "lllyasviel/control_v11e_sd15_ip2p",
    "lllyasviel/control_v11f1e_sd15_tile",

    #################################################
    # thibaud sd v2.1 models (ControlNet v1.0? or v1.1?)
    ##################################################
    "thibaud/controlnet-sd21-openpose-diffusers",
    "thibaud/controlnet-sd21-canny-diffusers",
    "thibaud/controlnet-sd21-depth-diffusers",
    "thibaud/controlnet-sd21-scribble-diffusers",
    "thibaud/controlnet-sd21-hed-diffusers",
    "thibaud/controlnet-sd21-zoedepth-diffusers",
    "thibaud/controlnet-sd21-color-diffusers",
    "thibaud/controlnet-sd21-openposev2-diffusers",
    "thibaud/controlnet-sd21-lineart-diffusers",
    "thibaud/controlnet-sd21-normalbae-diffusers",
    "thibaud/controlnet-sd21-ade20k-diffusers",

    ##############################################
    # ControlNetMediaPipeface, ControlNet v1.1
    ##############################################
    # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"],  # SD 1.5
    # diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
    # hacked t2l to split to model & subfolder if format is "model,subfolder"
    "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15",  # SD 1.5
    "CrucibleAI/ControlNetMediaPipeFace",  # SD 2.1?
]

CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
-CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
+CONTROLNET_MODE_VALUES = Literal[tuple(
+    ["balanced", "more_prompt", "more_control", "unbalanced"])]
# crop and fill options not ready yet
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]

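The "model,subfolder" convention in the comments above lets one list entry carry both a repo id and the subfolder that from_pretrained() needs. A hedged sketch of the split the comment describes; the helper name is mine, not the project's:

def split_model_spec(spec: str):
    # Split 'repo_id,subfolder' for use as from_pretrained(repo_id, subfolder=...).
    # Hypothetical helper illustrating the comment above.
    if "," in spec:
        repo_id, subfolder = spec.split(",", 1)
        return repo_id, subfolder
    return spec, None

print(split_model_spec("CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15"))
# ('CrucibleAI/ControlNetMediaPipeFace', 'diffusion_sd15')
print(split_model_spec("lllyasviel/control_v11e_sd15_ip2p"))
# ('lllyasviel/control_v11e_sd15_ip2p', None)
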
@@ -112,16 +95,22 @@ class ControlNetModelField(BaseModel):
    model_name: str = Field(description="Name of the ControlNet model")
    base_model: BaseModelType = Field(description="Base model")


class ControlField(BaseModel):
    image: ImageField = Field(default=None, description="The control image")
-    control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
+    control_model: Optional[ControlNetModelField] = Field(
+        default=None, description="The ControlNet model to use")
    # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
-    control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
-    begin_step_percent: float = Field(default=0, ge=0, le=1,
-                                      description="When the ControlNet is first applied (% of total steps)")
-    end_step_percent: float = Field(default=1, ge=0, le=1,
-                                    description="When the ControlNet is last applied (% of total steps)")
-    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
+    control_weight: Union[float, List[float]] = Field(
+        default=1, description="The weight given to the ControlNet")
+    begin_step_percent: float = Field(
+        default=0, ge=0, le=1,
+        description="When the ControlNet is first applied (% of total steps)")
+    end_step_percent: float = Field(
+        default=1, ge=0, le=1,
+        description="When the ControlNet is last applied (% of total steps)")
+    control_mode: CONTROLNET_MODE_VALUES = Field(
+        default="balanced", description="The control mode to use")
    # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

    @validator("control_weight")
@@ -130,11 +119,13 @@ class ControlField(BaseModel):
        if isinstance(v, list):
            for i in v:
                if i < -1 or i > 2:
-                    raise ValueError('Control weights must be within -1 to 2 range')
+                    raise ValueError(
+                        'Control weights must be within -1 to 2 range')
        else:
            if v < -1 or v > 2:
                raise ValueError('Control weights must be within -1 to 2 range')
        return v

    class Config:
        schema_extra = {
            "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
@@ -175,13 +166,14 @@ class ControlNetInvocation(BaseInvocation):
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
-                "tags": ["latents"],
+                "title": "ControlNet",
+                "tags": ["controlnet", "latents"],
                "type_hints": {
                    "model": "model",
                    "control": "control",
                    # "cfg_scale": "float",
                    "cfg_scale": "number",
                    "control_weight": "float",
                }
            },
        }
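The validator above accepts either a scalar weight or a per-step list and clamps everything to the [-1, 2] range. The same rule, restated as a standalone function for quick experimentation:

from typing import List, Union

def validate_control_weights(v: Union[float, List[float]]):
    # same acceptance rule as ControlField's validator above
    weights = v if isinstance(v, list) else [v]
    for w in weights:
        if w < -1 or w > 2:
            raise ValueError('Control weights must be within -1 to 2 range')
    return v

print(validate_control_weights(0.75))             # accepted
print(validate_control_weights([0.0, 1.5, 2.0]))  # accepted
try:
    validate_control_weights([0.5, 2.5])
except ValueError as e:
    print(e)                                      # rejected
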
@@ -208,6 +200,13 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
    image: ImageField = Field(default=None, description="The image to process")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Image Processor",
                "tags": ["image", "processor"]
            },
        }

    def run_processor(self, image):
        # superclass just passes through image without processing
@@ -239,14 +238,15 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
        return ImageOutput(
            image=processed_image_field,
            # width=processed_image.width,
-            width = image_dto.width,
+            width=image_dto.width,
            # height=processed_image.height,
-            height = image_dto.height,
+            height=image_dto.height,
            # mode=processed_image.mode,
        )


-class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class CannyImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Canny edge detection for ControlNet"""
    # fmt: off
    type: Literal["canny_image_processor"] = "canny_image_processor"
@@ -255,13 +255,23 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
    high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Canny Processor",
                "tags": ["controlnet", "canny", "image", "processor"]
            },
        }

    def run_processor(self, image):
        canny_processor = CannyDetector()
-        processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
+        processed_image = canny_processor(
+            image, self.low_threshold, self.high_threshold)
        return processed_image

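Of the processors in this file, CannyDetector is the only one that needs no pretrained weights; it wraps OpenCV edge detection. A hedged usage sketch mirroring run_processor() outside the node graph (the input path is a placeholder):

from PIL import Image
from controlnet_aux import CannyDetector

image = Image.open("input.png")  # placeholder path
canny = CannyDetector()
# positional thresholds match the run_processor() call above
edges = canny(image, 100, 200)
edges.save("canny_control.png")
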
-class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class HedImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies HED edge detection to image"""
    # fmt: off
    type: Literal["hed_image_processor"] = "hed_image_processor"
@@ -273,6 +283,14 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
    scribble: bool = Field(default=False, description="Whether to use scribble mode")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Softedge(HED) Processor",
                "tags": ["controlnet", "softedge", "hed", "image", "processor"]
            },
        }

    def run_processor(self, image):
        hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = hed_processor(image,
@@ -285,7 +303,8 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
        return processed_image


-class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class LineartImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies line art processing to image"""
    # fmt: off
    type: Literal["lineart_image_processor"] = "lineart_image_processor"
@@ -295,16 +314,25 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
    coarse: bool = Field(default=False, description="Whether to use coarse mode")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Lineart Processor",
                "tags": ["controlnet", "lineart", "image", "processor"]
            },
        }

    def run_processor(self, image):
-        lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = lineart_processor(image,
-                                            detect_resolution=self.detect_resolution,
-                                            image_resolution=self.image_resolution,
-                                            coarse=self.coarse)
+        lineart_processor = LineartDetector.from_pretrained(
+            "lllyasviel/Annotators")
+        processed_image = lineart_processor(
+            image, detect_resolution=self.detect_resolution,
+            image_resolution=self.image_resolution, coarse=self.coarse)
        return processed_image


-class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class LineartAnimeImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies line art anime processing to image"""
    # fmt: off
    type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
@@ -313,8 +341,17 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Lineart Anime Processor",
                "tags": ["controlnet", "lineart", "anime", "image", "processor"]
            },
        }

    def run_processor(self, image):
-        processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
+        processor = LineartAnimeDetector.from_pretrained(
+            "lllyasviel/Annotators")
        processed_image = processor(image,
                                    detect_resolution=self.detect_resolution,
                                    image_resolution=self.image_resolution,
@@ -322,7 +359,8 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
        return processed_image


-class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class OpenposeImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies Openpose processing to image"""
    # fmt: off
    type: Literal["openpose_image_processor"] = "openpose_image_processor"
@@ -332,17 +370,26 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Openpose Processor",
                "tags": ["controlnet", "openpose", "image", "processor"]
            },
        }

    def run_processor(self, image):
-        openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = openpose_processor(image,
-                                             detect_resolution=self.detect_resolution,
-                                             image_resolution=self.image_resolution,
-                                             hand_and_face=self.hand_and_face,
-                                             )
+        openpose_processor = OpenposeDetector.from_pretrained(
+            "lllyasviel/Annotators")
+        processed_image = openpose_processor(
+            image, detect_resolution=self.detect_resolution,
+            image_resolution=self.image_resolution,
+            hand_and_face=self.hand_and_face,)
        return processed_image


-class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class MidasDepthImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies Midas depth processing to image"""
    # fmt: off
    type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
@@ -353,6 +400,14 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
    # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Midas (Depth) Processor",
                "tags": ["controlnet", "midas", "depth", "image", "processor"]
            },
        }

    def run_processor(self, image):
        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = midas_processor(image,
@@ -364,7 +419,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
        return processed_image


-class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class NormalbaeImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies NormalBae processing to image"""
    # fmt: off
    type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
@@ -373,15 +429,25 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Normal BAE Processor",
                "tags": ["controlnet", "normal", "bae", "image", "processor"]
            },
        }

    def run_processor(self, image):
-        normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = normalbae_processor(image,
-                                              detect_resolution=self.detect_resolution,
-                                              image_resolution=self.image_resolution)
+        normalbae_processor = NormalBaeDetector.from_pretrained(
+            "lllyasviel/Annotators")
+        processed_image = normalbae_processor(
+            image, detect_resolution=self.detect_resolution,
+            image_resolution=self.image_resolution)
        return processed_image


-class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class MlsdImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies MLSD processing to image"""
    # fmt: off
    type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
@@ -392,17 +458,25 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
    thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "MLSD Processor",
                "tags": ["controlnet", "mlsd", "image", "processor"]
            },
        }

    def run_processor(self, image):
        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = mlsd_processor(image,
-                                         detect_resolution=self.detect_resolution,
-                                         image_resolution=self.image_resolution,
-                                         thr_v=self.thr_v,
-                                         thr_d=self.thr_d)
+        processed_image = mlsd_processor(
+            image, detect_resolution=self.detect_resolution,
+            image_resolution=self.image_resolution, thr_v=self.thr_v,
+            thr_d=self.thr_d)
        return processed_image


-class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class PidiImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies PIDI processing to image"""
    # fmt: off
    type: Literal["pidi_image_processor"] = "pidi_image_processor"
@@ -413,17 +487,26 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
    scribble: bool = Field(default=False, description="Whether to use scribble mode")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "PIDI Processor",
                "tags": ["controlnet", "pidi", "image", "processor"]
            },
        }

    def run_processor(self, image):
-        pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = pidi_processor(image,
-                                         detect_resolution=self.detect_resolution,
-                                         image_resolution=self.image_resolution,
-                                         safe=self.safe,
-                                         scribble=self.scribble)
+        pidi_processor = PidiNetDetector.from_pretrained(
+            "lllyasviel/Annotators")
+        processed_image = pidi_processor(
+            image, detect_resolution=self.detect_resolution,
+            image_resolution=self.image_resolution, safe=self.safe,
+            scribble=self.scribble)
        return processed_image


-class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class ContentShuffleImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies content shuffle processing to image"""
    # fmt: off
    type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
@@ -435,6 +518,14 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
    f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Content Shuffle Processor",
                "tags": ["controlnet", "contentshuffle", "image", "processor"]
            },
        }

    def run_processor(self, image):
        content_shuffle_processor = ContentShuffleDetector()
        processed_image = content_shuffle_processor(image,
@@ -448,19 +539,30 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvoca

# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
-class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class ZoeDepthImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies Zoe depth processing to image"""
    # fmt: off
    type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Zoe (Depth) Processor",
                "tags": ["controlnet", "zoe", "depth", "image", "processor"]
            },
        }

    def run_processor(self, image):
-        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
+        zoe_depth_processor = ZoeDetector.from_pretrained(
+            "lllyasviel/Annotators")
        processed_image = zoe_depth_processor(image)
        return processed_image

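The compatibility comment above ("controlnet_aux >= 0.0.4 and timm <= 0.6.13") can be checked at runtime; the bounds below restate the comment's claim and are not independently verified here:

from importlib.metadata import version  # distribution names assumed to match
from packaging.version import Version

assert Version(version("controlnet_aux")) >= Version("0.0.4")
assert Version(version("timm")) <= Version("0.6.13")
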
-class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class MediapipeFaceProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies mediapipe face processing to image"""
    # fmt: off
    type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
@@ -469,16 +571,27 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
    min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Mediapipe Processor",
                "tags": ["controlnet", "mediapipe", "image", "processor"]
            },
        }

    def run_processor(self, image):
        # MediaPipeFaceDetector throws an error if image has alpha channel
        # so convert to RGB if needed
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        mediapipe_face_processor = MediapipeFaceDetector()
-        processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
+        processed_image = mediapipe_face_processor(
+            image, max_faces=self.max_faces, min_confidence=self.min_confidence)
        return processed_image

-class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class LeresImageProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies leres processing to image"""
    # fmt: off
    type: Literal["leres_image_processor"] = "leres_image_processor"
@@ -490,18 +603,25 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Leres (Depth) Processor",
                "tags": ["controlnet", "leres", "depth", "image", "processor"]
            },
        }

    def run_processor(self, image):
        leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = leres_processor(image,
-                                          thr_a=self.thr_a,
-                                          thr_b=self.thr_b,
-                                          boost=self.boost,
-                                          detect_resolution=self.detect_resolution,
-                                          image_resolution=self.image_resolution)
+        processed_image = leres_processor(
+            image, thr_a=self.thr_a, thr_b=self.thr_b, boost=self.boost,
+            detect_resolution=self.detect_resolution,
+            image_resolution=self.image_resolution)
        return processed_image


-class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class TileResamplerProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):

    # fmt: off
    type: Literal["tile_image_processor"] = "tile_image_processor"
@@ -510,6 +630,14 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
    down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Tile Resample Processor",
                "tags": ["controlnet", "tile", "resample", "image", "processor"]
            },
        }

    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
    def tile_resample(self,
                      np_img: np.ndarray,
@@ -528,28 +656,33 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
    def run_processor(self, img):
        np_img = np.array(img, dtype=np.uint8)
        processed_np_image = self.tile_resample(np_img,
-                                                #res=self.tile_size,
+                                                # res=self.tile_size,
                                                down_sampling_rate=self.down_sampling_rate
                                                )
        processed_image = Image.fromarray(processed_np_image)
        return processed_image


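tile_resample() itself is elided by this diff; per the comment above it was copied from sd-webui-controlnet's scripts/processor.py. Under that assumption, a down-sampling resample of this shape looks roughly like the sketch below (not the verbatim upstream code):

import cv2
import numpy as np

def tile_resample_sketch(np_img: np.ndarray,
                         down_sampling_rate: float = 1.0) -> np.ndarray:
    # Shrink the control image so the tile ControlNet re-synthesizes detail.
    if down_sampling_rate <= 1.0:
        return np_img
    h, w = np_img.shape[:2]
    return cv2.resize(
        np_img,
        (int(w / down_sampling_rate), int(h / down_sampling_rate)),
        interpolation=cv2.INTER_AREA,
    )

img = np.zeros((512, 512, 3), dtype=np.uint8)
print(tile_resample_sketch(img, 2.0).shape)  # (256, 256, 3)
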
-class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class SegmentAnythingProcessorInvocation(
+        ImageProcessorInvocation, PILInvocationConfig):
    """Applies segment anything processing to image"""
    # fmt: off
    type: Literal["segment_anything_processor"] = "segment_anything_processor"
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {"ui": {"title": "Segment Anything Processor", "tags": [
            "controlnet", "segment", "anything", "sam", "image", "processor"]}, }

    def run_processor(self, image):
        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
-        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
+        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
+            "ybelkada/segment-anything", subfolder="checkpoints")
        np_img = np.array(image, dtype=np.uint8)
        processed_image = segment_anything_processor(np_img)
        return processed_image


class SamDetectorReproducibleColors(SamDetector):

    # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
@@ -561,7 +694,8 @@ class SamDetectorReproducibleColors(SamDetector):
            return
        sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
        h, w = anns[0]['segmentation'].shape
-        final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
+        final_img = Image.fromarray(
+            np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
        palette = ade_palette()
        for i, ann in enumerate(sorted_anns):
            m = ann['segmentation']
@@ -569,5 +703,8 @@ class SamDetectorReproducibleColors(SamDetector):
            # doing modulo just in case number of annotated regions exceeds number of colors in palette
            ann_color = palette[i % len(palette)]
            img[:, :] = ann_color
-            final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
+            final_img.paste(
+                Image.fromarray(img, mode="RGB"),
+                (0, 0),
+                Image.fromarray(np.uint8(m * 255)))
        return np.array(final_img, dtype=np.uint8)

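The point of the override above: upstream show_anns() colors each mask randomly, so the same segmentation yields a different control image on every run, while indexing ade_palette() by sorted-area rank makes the coloring deterministic. The modulo trick in isolation, with a toy palette:

# Region i always gets palette[i % len(palette)], no matter how many
# regions the detector found. Toy three-color palette for illustration.
palette = [(120, 120, 120), (180, 120, 120), (6, 230, 230)]

num_regions = 7  # pretend SAM produced more regions than colors
for i in range(num_regions):
    print(i, palette[i % len(palette)])
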
@@ -35,6 +35,14 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
    mask: ImageField = Field(default=None, description="The mask to use when inpainting")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "OpenCV Inpaint",
                "tags": ["opencv", "inpaint"]
            },
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)
        mask = context.services.images.get_pil_image(self.mask.image_name)

@@ -130,6 +130,7 @@ class InpaintInvocation(BaseInvocation):
        schema_extra = {
            "ui": {
                "tags": ["stable-diffusion", "image"],
+                "title": "Inpaint"
            },
        }

@@ -146,9 +147,13 @@ class InpaintInvocation(BaseInvocation):
            source_node_id=source_node_id,
        )

-    def get_conditioning(self, context):
-        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
+    def get_conditioning(self, context, unet):
+        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
+        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
+        extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
+
+        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
+        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)

        return (uc, c, extra_conditioning_info)

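The reworked get_conditioning() now takes the unet so it can move the stored embeddings onto the model's device and dtype before sampling; conditioning is saved to the latents store on CPU (see the compel invocations earlier in this diff) and would otherwise mismatch an fp16 CUDA UNet. The core move in isolation:

import torch

# Stand-ins: conditioning as loaded from the store, and a target UNet config.
embeds = torch.randn(1, 77, 768)  # starts on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
unet_dtype = torch.float16 if device == "cuda" else torch.float32

c = embeds.to(device=device, dtype=unet_dtype)
print(c.device, c.dtype)
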
@@ -157,13 +162,13 @@ class InpaintInvocation(BaseInvocation):
        def _lora_loader():
            for lora in self.unet.loras:
                lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"}), context=context,)
                yield (lora_info.context.model, lora.weight)
                del lora_info
            return

-        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
-        vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
+        vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)

        with vae_info as vae,\
                ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@@ -209,7 +214,6 @@ class InpaintInvocation(BaseInvocation):
        )
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

-        conditioning = self.get_conditioning(context)
        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
@@ -217,6 +221,8 @@ class InpaintInvocation(BaseInvocation):
        )

        with self.load_model_old_way(context, scheduler) as model:
+            conditioning = self.get_conditioning(context, model.context.model.unet)
+
            outputs = Inpaint(model).generate(
                conditioning=conditioning,
                scheduler=scheduler,

|
||||
default=None, description="The image to load"
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Load Image",
|
||||
"tags": ["image", "load"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -91,6 +100,14 @@ class ShowImageInvocation(BaseInvocation):
|
||||
default=None, description="The image to show"
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Show Image",
|
||||
"tags": ["image", "show"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
if image:
|
||||
@@ -119,6 +136,14 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Crop Image",
|
||||
"tags": ["image", "crop"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -157,6 +182,14 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Paste Image",
|
||||
"tags": ["image", "paste"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -207,6 +240,14 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Mask From Alpha",
|
||||
"tags": ["image", "mask", "alpha"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -241,6 +282,14 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image2: Optional[ImageField] = Field(default=None, description="The second image to multiply")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Multiply Images",
|
||||
"tags": ["image", "multiply"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
||||
image2 = context.services.images.get_pil_image(self.image2.image_name)
|
||||
@@ -277,6 +326,14 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Channel",
|
||||
"tags": ["image", "channel"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -312,6 +369,14 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Convert Image",
|
||||
"tags": ["image", "convert"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -345,6 +410,14 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Blur Image",
|
||||
"tags": ["image", "blur"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -404,6 +477,14 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Resize Image",
|
||||
"tags": ["image", "resize"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -437,11 +518,19 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["img_scale"] = "img_scale"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to scale")
|
||||
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Scale Image",
|
||||
"tags": ["image", "scale"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -482,6 +571,14 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Linear Interpolation",
|
||||
"tags": ["image", "linear", "interpolation", "lerp"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -518,6 +615,14 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Inverse Linear Interpolation",
|
||||
"tags": ["image", "linear", "interpolation", "inverse"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
    BaseInvocation,
+    InvocationConfig,
    InvocationContext,
)

@@ -133,6 +134,14 @@ class InfillColorInvocation(BaseInvocation):
        description="The color to use to infill",
    )

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Color Infill",
                "tags": ["image", "inpaint", "color", "infill"]
            },
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)

@@ -173,6 +182,14 @@ class InfillTileInvocation(BaseInvocation):
        default_factory=get_random_seed,
    )

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Tile Infill",
                "tags": ["image", "inpaint", "tile", "infill"]
            },
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)

@@ -206,6 +223,14 @@ class InfillPatchMatchInvocation(BaseInvocation):
        default=None, description="The image to infill"
    )

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Patch Match Infill",
                "tags": ["image", "inpaint", "patchmatch", "infill"]
            },
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)


@@ -22,7 +22,8 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
    PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import torch_dtype
+from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.model_management import ModelPatcher
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                             InvocationConfig, InvocationContext)
@@ -31,6 +32,13 @@ from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField

+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+

class LatentsField(BaseModel):
    """A latents field used for passing latents between invocations"""
@@ -76,7 +84,7 @@ def get_scheduler(
        scheduler_name, SCHEDULER_MAP['ddim']
    )
    orig_scheduler_info = context.services.model_manager.get_model(
-        **scheduler_info.dict()
+        **scheduler_info.dict(), context=context,
    )
    with orig_scheduler_info as orig_scheduler:
        scheduler_config = orig_scheduler.config
@@ -132,6 +140,7 @@ class TextToLatentsInvocation(BaseInvocation):
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
+                "title": "Text To Latents",
                "tags": ["latents"],
                "type_hints": {
                    "model": "model",
@@ -160,13 +169,14 @@ class TextToLatentsInvocation(BaseInvocation):
        self,
        context: InvocationContext,
        scheduler,
+        unet,
    ) -> ConditioningData:
-        c, extra_conditioning_info = context.services.latents.get(
-            self.positive_conditioning.conditioning_name
-        )
-        uc, _ = context.services.latents.get(
-            self.negative_conditioning.conditioning_name
-        )
+        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
+        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
+        extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
+
+        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
+        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)

        conditioning_data = ConditioningData(
            unconditioned_embeddings=uc,
@@ -188,7 +198,7 @@ class TextToLatentsInvocation(BaseInvocation):
            eta=0.0,  # ddim_eta

            # for ancestral and sde schedulers
-            generator=torch.Generator(device=uc.device).manual_seed(0),
+            generator=torch.Generator(device=unet.device).manual_seed(0),
        )
        return conditioning_data

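The generator change above matters for the ancestral and SDE schedulers noted in the comment: per-step noise is drawn from that generator, and a torch.Generator must live on the same device as the tensors it seeds. Since this refactor loads the embeddings from CPU storage, uc.device is no longer a safe proxy for the sampling device, hence device=unet.device. A sketch of the device-matched draw:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
gen = torch.Generator(device=device).manual_seed(0)

# Noise for an ancestral/SDE step: the generator and the tensor share a device.
noise = torch.randn(1, 4, 64, 64, generator=gen, device=device)
print(noise.device)
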
@@ -262,6 +272,7 @@ class TextToLatentsInvocation(BaseInvocation):
                    model_name=control_info.control_model.model_name,
                    model_type=ModelType.ControlNet,
                    base_model=control_info.control_model.base_model,
+                    context=context,
                )
            )

@@ -313,19 +324,21 @@ class TextToLatentsInvocation(BaseInvocation):
        def _lora_loader():
            for lora in self.unet.loras:
                lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"})
+                    **lora.dict(exclude={"weight"}), context=context,
                )
                yield (lora_info.context.model, lora.weight)
                del lora_info
            return

        unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict()
+            **self.unet.unet.dict(), context=context,
        )
        with ExitStack() as exit_stack,\
                ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
                unet_info as unet:

+            noise = noise.to(device=unet.device, dtype=unet.dtype)
+
            scheduler = get_scheduler(
                context=context,
                scheduler_info=self.unet.scheduler,
@@ -333,7 +346,7 @@ class TextToLatentsInvocation(BaseInvocation):
            )

            pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler)
+            conditioning_data = self.get_conditioning_data(context, scheduler, unet)

            control_data = self.prep_control_data(
                model=pipeline, context=context, control_input=self.control,
@@ -354,6 +367,7 @@ class TextToLatentsInvocation(BaseInvocation):
            )

            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
            result_latents = result_latents.to("cpu")
+            torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
@@ -377,6 +391,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
+                "title": "Latent To Latents",
                 "tags": ["latents"],
                 "type_hints": {
                     "model": "model",
@@ -403,19 +418,22 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"})
+                    **lora.dict(exclude={"weight"}), context=context,
                 )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return

         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict()
+            **self.unet.unet.dict(), context=context,
         )
         with ExitStack() as exit_stack,\
                 ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
                 unet_info as unet:

+            noise = noise.to(device=unet.device, dtype=unet.dtype)
+            latent = latent.to(device=unet.device, dtype=unet.dtype)
+
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -423,7 +441,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )

             pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler)
+            conditioning_data = self.get_conditioning_data(context, scheduler, unet)

             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,
@@ -455,6 +473,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )

             # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
             result_latents = result_latents.to("cpu")
             torch.cuda.empty_cache()

             name = f'{context.graph_execution_state_id}__{self.id}'
@@ -475,13 +494,14 @@ class LatentsToImageInvocation(BaseInvocation):
     tiled: bool = Field(
         default=False,
         description="Decode latents by overlapping tiles (less memory consumption)")
+    fp32: bool = Field(False, description="Decode in full precision")
+    metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")


     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
+                "title": "Latents To Image",
                 "tags": ["latents", "image"],
             },
         }
@@ -491,10 +511,36 @@ class LatentsToImageInvocation(BaseInvocation):
         latents = context.services.latents.get(self.latents.latents_name)

         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(),
+            **self.vae.vae.dict(), context=context,
         )

         with vae_info as vae:
             latents = latents.to(vae.device)
+            if self.fp32:
+                vae.to(dtype=torch.float32)
+
+                use_torch_2_0_or_xformers = isinstance(
+                    vae.decoder.mid_block.attentions[0].processor,
+                    (
+                        AttnProcessor2_0,
+                        XFormersAttnProcessor,
+                        LoRAXFormersAttnProcessor,
+                        LoRAAttnProcessor2_0,
+                    ),
+                )
+                # if xformers or torch 2.0 is used, the attention block does not need
+                # to be in float32, which can save lots of memory
+                if use_torch_2_0_or_xformers:
+                    vae.post_quant_conv.to(latents.dtype)
+                    vae.decoder.conv_in.to(latents.dtype)
+                    vae.decoder.mid_block.to(latents.dtype)
+                else:
+                    latents = latents.float()
+
+            else:
+                vae.to(dtype=torch.float16)
+                latents = latents.half()
+
             if self.tiled or context.services.configuration.tiled_decode:
                 vae.enable_tiling()
             else:
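A note on the fp32 branch added above: the torch 2.0 (SDPA) and xFormers attention processors handle mixed precision internally, so only the convolutions around the VAE's attention block need to match the latents' dtype. A minimal standalone sketch of the same dtype juggling, assuming an already-loaded diffusers AutoencoderKL named `vae` (the helper name `decode_fp32` is illustrative, not part of the node):

import torch
from diffusers.models.attention_processor import AttnProcessor2_0

def decode_fp32(vae, latents):
    # hypothetical sketch of the fp32-decode logic in the hunk above
    vae.to(dtype=torch.float32)
    sdpa = isinstance(vae.decoder.mid_block.attentions[0].processor, AttnProcessor2_0)
    if sdpa:
        # SDPA tolerates fp16 inputs, so these parts may stay in the latents' dtype
        vae.post_quant_conv.to(latents.dtype)
        vae.decoder.conv_in.to(latents.dtype)
        vae.decoder.mid_block.to(latents.dtype)
    else:
        latents = latents.float()  # everything must match fp32
    return vae.decode(latents).sample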
@@ -553,17 +599,29 @@ class ResizeLatentsInvocation(BaseInvocation):
     antialias: bool = Field(
         default=False,
         description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Resize Latents",
+                "tags": ["latents", "resize"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)

+        # TODO:
+        device = choose_torch_device()
+
         resized_latents = torch.nn.functional.interpolate(
-            latents, size=(self.height // 8, self.width // 8),
+            latents.to(device), size=(self.height // 8, self.width // 8),
             mode=self.mode, antialias=self.antialias
             if self.mode in ["bilinear", "bicubic"] else False,
         )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()

         name = f"{context.graph_execution_state_id}__{self.id}"
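Because the resize above operates in latent space, the requested pixel dimensions are divided by 8, the SD VAE's downscale factor. A quick standalone illustration of the interpolate call:

import torch

# a 512x512-pixel image is a 64x64 latent; resize it to 768x768 pixels
latents = torch.randn(1, 4, 64, 64)
resized = torch.nn.functional.interpolate(
    latents, size=(768 // 8, 768 // 8), mode="bilinear", antialias=True
)
print(resized.shape)  # torch.Size([1, 4, 96, 96])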
@@ -587,18 +645,30 @@ class ScaleLatentsInvocation(BaseInvocation):
     antialias: bool = Field(
         default=False,
         description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Scale Latents",
+                "tags": ["latents", "scale"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)

+        # TODO:
+        device = choose_torch_device()
+
         # resizing
         resized_latents = torch.nn.functional.interpolate(
-            latents, scale_factor=self.scale_factor, mode=self.mode,
+            latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
             antialias=self.antialias
             if self.mode in ["bilinear", "bicubic"] else False,
         )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()

         name = f"{context.graph_execution_state_id}__{self.id}"
@@ -618,12 +688,15 @@ class ImageToLatentsInvocation(BaseInvocation):
     tiled: bool = Field(
         default=False,
         description="Encode latents by overlapping tiles (less memory consumption)")
+    fp32: bool = Field(False, description="Decode in full precision")


     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
-                "tags": ["latents", "image"],
+                "title": "Image To Latents",
+                "tags": ["latents", "image"]
             },
         }

@@ -636,7 +709,7 @@ class ImageToLatentsInvocation(BaseInvocation):

         #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(),
+            **self.vae.vae.dict(), context=context,
         )

         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
@@ -644,6 +717,32 @@ class ImageToLatentsInvocation(BaseInvocation):
         image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

         with vae_info as vae:
+            orig_dtype = vae.dtype
+            if self.fp32:
+                vae.to(dtype=torch.float32)
+
+                use_torch_2_0_or_xformers = isinstance(
+                    vae.decoder.mid_block.attentions[0].processor,
+                    (
+                        AttnProcessor2_0,
+                        XFormersAttnProcessor,
+                        LoRAXFormersAttnProcessor,
+                        LoRAAttnProcessor2_0,
+                    ),
+                )
+                # if xformers or torch 2.0 is used, the attention block does not need
+                # to be in float32, which can save lots of memory
+                if use_torch_2_0_or_xformers:
+                    vae.post_quant_conv.to(orig_dtype)
+                    vae.decoder.conv_in.to(orig_dtype)
+                    vae.decoder.mid_block.to(orig_dtype)
+                #else:
+                #    latents = latents.float()
+
+            else:
+                vae.to(dtype=torch.float16)
+                #latents = latents.half()
+
             if self.tiled:
                 vae.enable_tiling()
             else:
@@ -658,8 +757,9 @@ class ImageToLatentsInvocation(BaseInvocation):
         ) # FIXME: uses torch.randn. make reproducible!

         latents = 0.18215 * latents
+        latents = latents.to(dtype=orig_dtype)

         name = f"{context.graph_execution_state_id}__{self.id}"
         # context.services.latents.set(name, latents)
+        latents = latents.to("cpu")
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents)
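The 0.18215 above is the standard Stable Diffusion VAE latent scaling factor: the encode node multiplies by it, and the decode paths (e.g. `latents = 1 / 0.18215 * latents` in the ONNX decoder further down) divide by it, so encode and decode are exact inverses up to the VAE itself. A one-line sanity check:

SCALE = 0.18215
z = SCALE * 1.0            # encode side (ImageToLatentsInvocation)
x = (1 / SCALE) * z        # decode side (LatentsToImageInvocation)
assert abs(x - 1.0) < 1e-9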
@@ -52,6 +52,14 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
     b: int = Field(default=0, description="The second number")
     # fmt: on

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Add",
+                "tags": ["math", "add"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=self.a + self.b)

@@ -65,6 +73,14 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
     b: int = Field(default=0, description="The second number")
     # fmt: on

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Subtract",
+                "tags": ["math", "subtract"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=self.a - self.b)

@@ -78,6 +94,14 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
     b: int = Field(default=0, description="The second number")
     # fmt: on

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Multiply",
+                "tags": ["math", "multiply"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=self.a * self.b)

@@ -91,6 +115,14 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
     b: int = Field(default=0, description="The second number")
     # fmt: on

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Divide",
+                "tags": ["math", "divide"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=int(self.a / self.b))

@@ -105,5 +137,14 @@ class RandomIntInvocation(BaseInvocation):
         default=np.iinfo(np.int32).max, description="The exclusive high value"
     )
     # fmt: on

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Random Integer",
+                "tags": ["math", "random", "integer"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=np.random.randint(self.low, self.high))
@@ -3,7 +3,7 @@ from typing import Literal, Optional, Union
 from pydantic import BaseModel, Field

 from invokeai.app.invocations.baseinvocation import (BaseInvocation,
-                                                     BaseInvocationOutput,
+                                                     BaseInvocationOutput, InvocationConfig,
                                                      InvocationContext)
 from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
@@ -97,6 +97,14 @@ class MetadataAccumulatorInvocation(BaseInvocation):
         description="The VAE used for decoding, if the main model's default was not used",
     )

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Metadata Accumulator",
+                "tags": ["image", "metadata", "generation"]
+            },
+        }
+

     def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
         """Collects and outputs a CoreMetadata object"""
@@ -33,7 +33,6 @@ class ClipField(BaseModel):
     skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
     loras: List[LoraInfo] = Field(description="Loras to apply on model loading")

-
 class VaeField(BaseModel):
     # TODO: better naming?
     vae: ModelInfo = Field(description="Info to load vae submodel")
@@ -50,12 +49,12 @@ class ModelLoaderOutput(BaseInvocationOutput):
     vae: VaeField = Field(default=None, description="Vae submodel")
     # fmt: on


 class MainModelField(BaseModel):
     """Main model field"""

     model_name: str = Field(description="Name of the model")
     base_model: BaseModelType = Field(description="Base model")
+    model_type: ModelType = Field(description="Model Type")


 class LoRAModelField(BaseModel):
@@ -64,7 +63,6 @@ class LoRAModelField(BaseModel):
     model_name: str = Field(description="Name of the LoRA model")
     base_model: BaseModelType = Field(description="Base model")

-
 class MainModelLoaderInvocation(BaseInvocation):
     """Loads a main model, outputting its submodels."""

@@ -157,6 +155,22 @@ class MainModelLoaderInvocation(BaseInvocation):
                 loras=[],
                 skipped_layers=0,
             ),
+            clip2=ClipField(
+                tokenizer=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.Tokenizer2,
+                ),
+                text_encoder=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.TextEncoder2,
+                ),
+                loras=[],
+                skipped_layers=0,
+            ),
             vae=VaeField(
                 vae=ModelInfo(
                     model_name=model_name,
@@ -167,7 +181,7 @@ class MainModelLoaderInvocation(BaseInvocation):
             ),
         )


 class LoraLoaderOutput(BaseInvocationOutput):
     """Model loader output"""

@@ -208,6 +222,9 @@ class LoraLoaderInvocation(BaseInvocation):
         base_model = self.lora.base_model
         lora_name = self.lora.model_name

+        # TODO: ui rewrite
+        base_model = BaseModelType.StableDiffusion1
+
         if not context.services.model_manager.model_exists(
             base_model=base_model,
             model_name=lora_name,
@@ -48,7 +48,7 @@ def get_noise(
         dtype=torch_dtype(device),
         device=noise_device_type,
         generator=generator,
-    ).to(device)
+    ).to("cpu")

     return noise_tensor

@@ -112,6 +112,7 @@ class NoiseInvocation(BaseInvocation):
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
+                "title": "Noise",
                 "tags": ["latents", "noise"],
             },
         }
591  invokeai/app/invocations/onnx.py  (new file)
@@ -0,0 +1,591 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)

from contextlib import ExitStack
from typing import List, Literal, Optional, Union

import re
import inspect

from pydantic import BaseModel, Field, validator
import torch
import numpy as np
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler

from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management import ONNXModelPatcher
from ...backend.util import choose_torch_device
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                             InvocationConfig, InvocationContext)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField

from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.backend import BaseModelType, ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.stable_diffusion import PipelineIntermediateState

from tqdm import tqdm
from .model import ClipField
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
from .compel import CompelOutput


ORT_TO_NP_TYPE = {
    "tensor(bool)": np.bool_,
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(int16)": np.int16,
    "tensor(uint16)": np.uint16,
    "tensor(int32)": np.int32,
    "tensor(uint32)": np.uint32,
    "tensor(int64)": np.int64,
    "tensor(uint64)": np.uint64,
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}
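For orientation: ONNX Runtime reports tensor dtypes as strings like "tensor(float16)", and the NumPy arrays fed to a session must use the matching dtype; that is what this table is for, and it is how the denoising loop below picks the dtype of its timestep input. A hedged illustration (the string value is a made-up example of what session.get_inputs() can report):

timestep_dtype_str = "tensor(float16)"         # as reported by an ONNX session
np_dtype = ORT_TO_NP_TYPE[timestep_dtype_str]  # -> np.float16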
class ONNXPromptInvocation(BaseInvocation):
    type: Literal["prompt_onnx"] = "prompt_onnx"

    prompt: str = Field(default="", description="Prompt")
    clip: ClipField = Field(None, description="Clip to use")

    def invoke(self, context: InvocationContext) -> CompelOutput:
        tokenizer_info = context.services.model_manager.get_model(
            **self.clip.tokenizer.dict(),
        )
        text_encoder_info = context.services.model_manager.get_model(
            **self.clip.text_encoder.dict(),
        )
        with tokenizer_info as orig_tokenizer,\
             text_encoder_info as text_encoder,\
             ExitStack() as stack:

            #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

            ti_list = []
            for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                name = trigger[1:-1]
                try:
                    ti_list.append(
                        #stack.enter_context(
                        #    context.services.model_manager.get_model(
                        #        model_name=name,
                        #        base_model=self.clip.text_encoder.base_model,
                        #        model_type=ModelType.TextualInversion,
                        #    )
                        #)
                        context.services.model_manager.get_model(
                            model_name=name,
                            base_model=self.clip.text_encoder.base_model,
                            model_type=ModelType.TextualInversion,
                        ).context.model
                    )
                except Exception:
                    #print(e)
                    #import traceback
                    #print(traceback.format_exc())
                    print(f'Warning: trigger "{trigger}" not found')

            with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
                 ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):

                text_encoder.create_session()

                # copied from
                # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
                text_inputs = tokenizer(
                    self.prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="np",
                )
                text_input_ids = text_inputs.input_ids
                """
                untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

                if not np.array_equal(text_input_ids, untruncated_ids):
                    removed_text = self.tokenizer.batch_decode(
                        untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                    )
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                    )
                """

                prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]

            text_encoder.release_session()

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

        # TODO: hacky but works ;D maybe rename latents somehow?
        context.services.latents.save(conditioning_name, (prompt_embeds, None))

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )
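The prompt scan above treats any angle-bracketed token as a potential textual-inversion trigger. A small illustration of what the regex matches (the prompt text here is made up):

import re

pattern = r"<[a-zA-Z0-9., _-]+>"
prompt = "a photo of <easynegative>, best quality <my_style>"
print(re.findall(pattern, prompt))  # ['<easynegative>', '<my_style>']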
# Text to image
class ONNXTextToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""

    type: Literal["t2l_onnx"] = "t2l_onnx"

    # Inputs
    # fmt: off
    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance; higher values may result in an image closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
    unet: UNetField = Field(default=None, description="UNet submodel")
    #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @validator("cfg_scale")
    def ge_one(cls, v):
        """validate that all cfg_scale values are >= 1"""
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError('cfg_scale must be greater than 1')
        else:
            if v < 1:
                raise ValueError('cfg_scale must be greater than 1')
        return v

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents"],
                "type_hints": {
                    "model": "model",
                    # "cfg_scale": "float",
                    "cfg_scale": "number"
                }
            },
        }

    # based on
    # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
        graph_execution_state = context.services.graph_execution_manager.get(
            context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
        if isinstance(c, torch.Tensor):
            c = c.cpu().numpy()
        if isinstance(uc, torch.Tensor):
            uc = uc.cpu().numpy()
        device = torch.device(choose_torch_device())
        prompt_embeds = np.concatenate([uc, c])

        latents = context.services.latents.get(self.noise.latents_name)
        if isinstance(latents, torch.Tensor):
            latents = latents.cpu().numpy()

        # TODO: better execution device handling
        latents = latents.astype(np.float16)

        # get the initial random noise unless the user supplied it
        do_classifier_free_guidance = True
        #latents_dtype = prompt_embeds.dtype
        #latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
        #if latents.shape != latents_shape:
        #    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
            scheduler_name=self.scheduler,
        )

        def torch2numpy(latent: torch.Tensor):
            return latent.cpu().numpy()

        def numpy2torch(latent, device):
            return torch.from_numpy(latent).to(device)

        def dispatch_progress(
                self, context: InvocationContext, source_node_id: str,
                intermediate_state: PipelineIntermediateState) -> None:
            stable_diffusion_step_callback(
                context=context,
                intermediate_state=intermediate_state,
                node=self.dict(),
                source_node_id=source_node_id,
            )

        scheduler.set_timesteps(self.steps)
        latents = latents * np.float64(scheduler.init_noise_sigma)

        extra_step_kwargs = dict()
        if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
            extra_step_kwargs.update(
                eta=0.0,
            )

        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())

        with unet_info as unet,\
             ExitStack() as stack:

            #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]

            with ONNXModelPatcher.apply_lora_unet(unet, loras):
                # TODO:
                unet.create_session()

                timestep_dtype = next(
                    (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
                )
                timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
                import time
                times = []
                for i in tqdm(range(len(scheduler.timesteps))):
                    t = scheduler.timesteps[i]
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
                    latent_model_input = latent_model_input.cpu().numpy()

                    # predict the noise residual
                    timestep = np.array([t], dtype=timestep_dtype)
                    start_time = time.time()
                    noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
                    times.append(time.time() - start_time)
                    noise_pred = noise_pred[0]

                    # perform guidance
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                        noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    scheduler_output = scheduler.step(
                        numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
                    )
                    latents = torch2numpy(scheduler_output.prev_sample)

                    state = PipelineIntermediateState(
                        run_id="test",
                        step=i,
                        timestep=timestep,
                        latents=scheduler_output.prev_sample
                    )
                    dispatch_progress(
                        self,
                        context=context,
                        source_node_id=source_node_id,
                        intermediate_state=state
                    )

                    # call the callback, if provided
                    #if callback is not None and i % callback_steps == 0:
                    #    callback(i, t, latents)
                print(times)
                unet.release_session()

        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
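The guidance step in the loop above is plain NumPy arithmetic: the UNet runs once on a doubled batch (unconditioned half stacked on conditioned half), the output is split, and the halves are recombined. A self-contained sketch with stand-in shapes:

import numpy as np

cfg_scale = 7.5
noise_pred = np.random.randn(2, 4, 64, 64).astype(np.float16)  # stand-in for a UNet output
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
guided = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # (1, 4, 64, 64)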
# Latent to image
class ONNXLatentsToImageInvocation(BaseInvocation):
    """Generates an image from latents."""

    type: Literal["l2i_onnx"] = "l2i_onnx"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
    vae: VaeField = Field(default=None, description="Vae submodel")
    metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
    #tiled: bool = Field(default=False, description="Decode latents by overlapping tiles (less memory consumption)")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents", "image"],
            },
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        if self.vae.vae.submodel != SubModelType.VaeDecoder:
            raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")

        vae_info = context.services.model_manager.get_model(
            **self.vae.vae.dict(),
        )

        # clear memory as vae decode can request a lot
        torch.cuda.empty_cache()

        with vae_info as vae:
            vae.create_session()

            # copied from
            # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
            latents = 1 / 0.18215 * latents
            # image = self.vae_decoder(latent_sample=latents)[0]
            # there seems to be a strange result when using the half-precision vae decoder if batch size > 1
            image = np.concatenate(
                [vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
            )

            image = np.clip(image / 2 + 0.5, 0, 1)
            image = image.transpose((0, 2, 3, 1))
            image = VaeImageProcessor.numpy_to_pil(image)[0]

            vae.release_session()

        torch.cuda.empty_cache()

        image_dto = context.services.images.create(
            image=image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            metadata=self.metadata.dict() if self.metadata else None,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
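The decode above deliberately feeds the ONNX VAE one latent at a time, to sidestep the half-precision batch-size>1 artifact noted in the comment. A sketch of the same batching pattern with a dummy decoder standing in for the real ONNX session wrapper:

import numpy as np

def fake_vae(latent_sample):
    # pretend 8x spatial upscale, mirroring the real decoder's call signature
    return [latent_sample.repeat(8, axis=2).repeat(8, axis=3)]

latents = np.random.randn(3, 4, 64, 64).astype(np.float16)
image = np.concatenate(
    [fake_vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
print(image.shape)  # (3, 4, 512, 512)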
class ONNXModelLoaderOutput(BaseInvocationOutput):
    """Model loader output"""

    #fmt: off
    type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"

    unet: UNetField = Field(default=None, description="UNet submodel")
    clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
    vae_decoder: VaeField = Field(default=None, description="Vae submodel")
    vae_encoder: VaeField = Field(default=None, description="Vae submodel")
    #fmt: on


class ONNXSD1ModelLoaderInvocation(BaseInvocation):
    """Loads the submodels of the selected model."""

    type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"

    model_name: str = Field(default="", description="Model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["model", "loader"],
                "type_hints": {
                    "model_name": "model"  # TODO: rename to model_name?
                }
            },
        }

    def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:

        model_name = "stable-diffusion-v1-5"
        base_model = BaseModelType.StableDiffusion1

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.ONNX,
        ):
            raise Exception(f"Unknown model name: {model_name}!")

        return ONNXModelLoaderOutput(
            unet=UNetField(
                unet=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.UNet,
                ),
                scheduler=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.Scheduler,
                ),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.Tokenizer,
                ),
                text_encoder=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.TextEncoder,
                ),
                loras=[],
            ),
            vae_decoder=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.VaeDecoder,
                ),
            ),
            vae_encoder=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.VaeEncoder,
                ),
            )
        )


class OnnxModelField(BaseModel):
    """Onnx model field"""

    model_name: str = Field(description="Name of the model")
    base_model: BaseModelType = Field(description="Base model")
    model_type: ModelType = Field(description="Model Type")


class OnnxModelLoaderInvocation(BaseInvocation):
    """Loads a main model, outputting its submodels."""

    type: Literal["onnx_model_loader"] = "onnx_model_loader"

    model: OnnxModelField = Field(description="The model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Onnx Model Loader",
                "tags": ["model", "loader"],
                "type_hints": {"model": "model"},
            },
        }

    def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
        base_model = self.model.base_model
        model_name = self.model.model_name
        model_type = ModelType.ONNX

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        ):
            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

        """
        if not context.services.model_manager.model_exists(
            model_name=self.model_name,
            model_type=SDModelType.Diffusers,
            submodel=SDModelType.Tokenizer,
        ):
            raise Exception(
                f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
            )

        if not context.services.model_manager.model_exists(
            model_name=self.model_name,
            model_type=SDModelType.Diffusers,
            submodel=SDModelType.TextEncoder,
        ):
            raise Exception(
                f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
            )

        if not context.services.model_manager.model_exists(
            model_name=self.model_name,
            model_type=SDModelType.Diffusers,
            submodel=SDModelType.UNet,
        ):
            raise Exception(
                f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
            )
        """

        return ONNXModelLoaderOutput(
            unet=UNetField(
                unet=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.UNet,
                ),
                scheduler=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Scheduler,
                ),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Tokenizer,
                ),
                text_encoder=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.TextEncoder,
                ),
                loras=[],
                skipped_layers=0,
            ),
            vae_decoder=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.VaeDecoder,
                ),
            ),
            vae_encoder=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.VaeEncoder,
                ),
            )
        )
@@ -43,6 +43,14 @@ class FloatLinearRangeInvocation(BaseInvocation):
     stop: float = Field(default=10, description="The last value of the range")
     steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Linear Range (Float)",
+                "tags": ["math", "float", "linear", "range"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
         param_list = list(np.linspace(self.start, self.stop, self.steps))
         return FloatCollectionOutput(
@@ -113,6 +121,14 @@ class StepParamEasingInvocation(BaseInvocation):
     show_easing_plot: bool = Field(default=False, description="show easing plot")
     # fmt: on

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Param Easing By Step",
+                "tags": ["param", "step", "easing"]
+            },
+        }
+

     def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
         log_diagnostics = False
@@ -1,9 +1,12 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

 from typing import Literal

 from pydantic import Field
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
-from .math import IntOutput, FloatOutput
+
+from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
+                             InvocationConfig, InvocationContext)
+from .math import FloatOutput, IntOutput

 # Pass-through parameter nodes - used by subgraphs

@@ -14,6 +17,14 @@ class ParamIntInvocation(BaseInvocation):
     a: int = Field(default=0, description="The integer value")
     #fmt: on

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["param", "integer"],
+                "title": "Integer Parameter"
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=self.a)

@@ -24,5 +35,36 @@ class ParamFloatInvocation(BaseInvocation):
     param: float = Field(default=0.0, description="The float value")
     #fmt: on

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["param", "float"],
+                "title": "Float Parameter"
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> FloatOutput:
         return FloatOutput(param=self.param)
+
+class StringOutput(BaseInvocationOutput):
+    """A string output"""
+    type: Literal["string_output"] = "string_output"
+    text: str = Field(default=None, description="The output string")
+
+
+class ParamStringInvocation(BaseInvocation):
+    """A string parameter"""
+    type: Literal['param_string'] = 'param_string'
+    text: str = Field(default='', description='The string value')
+
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["param", "string"],
+                "title": "String Parameter"
+            },
+        }
+
+    def invoke(self, context: InvocationContext) -> StringOutput:
+        return StringOutput(text=self.text)
@@ -4,7 +4,7 @@ from typing import Literal, Optional
 import numpy as np
 from pydantic import Field, validator

-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
 from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator

 class PromptOutput(BaseInvocationOutput):
@@ -48,6 +48,14 @@ class DynamicPromptInvocation(BaseInvocation):
         default=False, description="Whether to use the combinatorial generator"
     )

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Dynamic Prompt",
+                "tags": ["prompt", "dynamic"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
         if self.combinatorial:
             generator = CombinatorialPromptGenerator()
@@ -72,6 +80,14 @@ class PromptsFromFileInvocation(BaseInvocation):
     max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
     #fmt: on

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Prompts From File",
+                "tags": ["prompt", "file"]
+            },
+        }
+
     @validator("file_path")
     def file_path_exists(cls, v):
         if not exists(v):
662  invokeai/app/invocations/sdxl.py  (new file)
@@ -0,0 +1,662 @@
import torch
import inspect
from tqdm import tqdm
from typing import List, Literal, Optional, Union

from pydantic import Field, validator

from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                             InvocationConfig, InvocationContext)

from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
from .compel import ConditioningField
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output


class SDXLModelLoaderOutput(BaseInvocationOutput):
    """SDXL base model loader output"""

    # fmt: off
    type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"

    unet: UNetField = Field(default=None, description="UNet submodel")
    clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
    clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
    vae: VaeField = Field(default=None, description="Vae submodel")
    # fmt: on


class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
    """SDXL refiner model loader output"""
    # fmt: off
    type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
    unet: UNetField = Field(default=None, description="UNet submodel")
    clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
    vae: VaeField = Field(default=None, description="Vae submodel")
    # fmt: on


class SDXLModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl base model, outputting its submodels."""

    type: Literal["sdxl_model_loader"] = "sdxl_model_loader"

    model: MainModelField = Field(description="The model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Model Loader",
                "tags": ["model", "loader", "sdxl"],
                "type_hints": {"model": "model"},
            },
        }

    def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
        base_model = self.model.base_model
        model_name = self.model.model_name
        model_type = ModelType.Main

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        ):
            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

        return SDXLModelLoaderOutput(
            unet=UNetField(
                unet=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.UNet,
                ),
                scheduler=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Scheduler,
                ),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Tokenizer,
                ),
                text_encoder=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.TextEncoder,
                ),
                loras=[],
                skipped_layers=0,
            ),
            clip2=ClipField(
                tokenizer=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Tokenizer2,
                ),
                text_encoder=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.TextEncoder2,
                ),
                loras=[],
                skipped_layers=0,
            ),
            vae=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Vae,
                ),
            ),
        )


class SDXLRefinerModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl refiner model, outputting its submodels."""
    type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"

    model: MainModelField = Field(description="The model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Refiner Model Loader",
                "tags": ["model", "loader", "sdxl_refiner"],
                "type_hints": {"model": "model"},
            },
        }

    def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
        base_model = self.model.base_model
        model_name = self.model.model_name
        model_type = ModelType.Main

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        ):
            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

        return SDXLRefinerModelLoaderOutput(
            unet=UNetField(
                unet=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.UNet,
                ),
                scheduler=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Scheduler,
                ),
                loras=[],
            ),
            clip2=ClipField(
                tokenizer=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Tokenizer2,
                ),
                text_encoder=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.TextEncoder2,
                ),
                loras=[],
                skipped_layers=0,
            ),
            vae=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Vae,
                ),
            ),
        )
# Text to image
|
||||
class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["t2l_sdxl"] = "t2l_sdxl"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
||||
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Text To Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
prompt_embeds = positive_cond_data.conditionings[0].embeds
|
||||
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
|
||||
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
|
||||
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
|
||||
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
num_inference_steps = self.steps
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
timesteps = scheduler.timesteps
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict()
|
||||
)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
with unet_info as unet:
|
||||
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
eta=0.0,
|
||||
)
|
||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
generator=torch.Generator(device=unet.device).manual_seed(0),
|
||||
)
|
||||
|
||||
num_warmup_steps = len(timesteps) - self.steps * scheduler.order
|
||||
|
||||
# apply denoising_end
|
||||
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
|
||||
num_inference_steps = num_inference_steps - skipped_final_steps
|
||||
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
|
||||
|
||||
if not context.services.configuration.sequential_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
                        # expand the latents if we are doing classifier free guidance
                        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                        # predict the noise residual
                        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                        noise_pred = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        # perform guidance
                        if do_classifier_free_guidance:
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                            noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
                            # del noise_pred_uncond
                            # del noise_pred_text

                        # if do_classifier_free_guidance and guidance_rescale > 0.0:
                        #     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                            progress_bar.update()
                            # if callback is not None and i % callback_steps == 0:
                            #     callback(i, t, latents)
            else:
                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
                pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
                latents = latents.to(device=unet.device, dtype=unet.dtype)

                with tqdm(total=num_inference_steps) as progress_bar:
                    for i, t in enumerate(timesteps):
                        # expand the latents if we are doing classifier free guidance
                        # latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                        latent_model_input = scheduler.scale_model_input(latents, t)

                        # import gc
                        # gc.collect()
                        # torch.cuda.empty_cache()

                        # predict the noise residual

                        added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids}
                        noise_pred_uncond = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=negative_prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
                        noise_pred_text = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        # perform guidance
                        noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

                        # del noise_pred_text
                        # del noise_pred_uncond
                        # import gc
                        # gc.collect()
                        # torch.cuda.empty_cache()

                        # if do_classifier_free_guidance and guidance_rescale > 0.0:
                        #     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                        # del noise_pred
                        # import gc
                        # gc.collect()
                        # torch.cuda.empty_cache()

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                            progress_bar.update()
                            # if callback is not None and i % callback_steps == 0:
                            #     callback(i, t, latents)

            #################

            latents = latents.to("cpu")
            torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)
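The two guidance branches above trade memory for speed: the batched path runs one UNet call on a doubled batch, while the sequential path runs two half-size calls. Both reduce to the same combine step; a minimal sketch (names are stand-ins for the invocation's tensors, not its actual objects):

import torch

def cfg_combine(uncond: torch.Tensor, text: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # classifier-free guidance: push the prediction away from the unconditioned
    # result along the (text - uncond) direction, scaled by cfg_scale
    return uncond + cfg_scale * (text - uncond)

# batched: one forward pass over a doubled batch, then split in two
#   noise_pred = unet(torch.cat([latents] * 2), ...)
#   uncond, text = noise_pred.chunk(2)
# sequential: two forward passes at the original batch size, roughly halving
# peak activation memory at the cost of a second UNet invocation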

class SDXLLatentsToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""

    type: Literal["l2l_sdxl"] = "l2l_sdxl"

    # Inputs
    # fmt: off
    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
    unet: UNetField = Field(default=None, description="UNet submodel")
    latents: Optional[LatentsField] = Field(description="Initial latents")

    denoising_start: float = Field(default=0.0, ge=0, lt=1, description="")
    denoising_end: float = Field(default=1.0, gt=0, le=1, description="")

    #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @validator("cfg_scale")
    def ge_one(cls, v):
        """validate that all cfg_scale values are >= 1"""
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError('cfg_scale must be greater than 1')
        else:
            if v < 1:
                raise ValueError('cfg_scale must be greater than 1')
        return v

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Latents to Latents",
                "tags": ["latents"],
                "type_hints": {
                    "model": "model",
                    # "cfg_scale": "float",
                    "cfg_scale": "number"
                }
            },
        }

    # based on
    # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
        prompt_embeds = positive_cond_data.conditionings[0].embeds
        pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
        add_time_ids = positive_cond_data.conditionings[0].add_time_ids

        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
        negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
        negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
        add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids

        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
            scheduler_name=self.scheduler,
        )

        # apply denoising_start
        num_inference_steps = self.steps
        scheduler.set_timesteps(num_inference_steps)

        t_start = int(round(self.denoising_start * num_inference_steps))
        timesteps = scheduler.timesteps[t_start * scheduler.order:]
        num_inference_steps = num_inference_steps - t_start

        # apply noise (if provided)
        if self.noise is not None:
            noise = context.services.latents.get(self.noise.latents_name)
            latents = scheduler.add_noise(latents, noise, timesteps[:1])
            del noise
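The add_noise call above seeds the image-to-image start point: the incoming latents are noised to the level of the first timestep that will actually be denoised. A sketch with a stock diffusers scheduler (shapes and values are illustrative, not the invocation's real objects):

import torch
from diffusers import DDIMScheduler  # any diffusers scheduler with add_noise behaves the same way

scheduler = DDIMScheduler()  # defaults are fine for illustrating the call
scheduler.set_timesteps(30)

latents = torch.randn(1, 4, 128, 128)  # latents to start denoising from
noise = torch.randn_like(latents)

# timesteps[:1] is the first (noisiest) step that will run, so the latents are
# noised exactly enough for the remaining schedule to denoise them back down
noisy = scheduler.add_noise(latents, noise, scheduler.timesteps[:1])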

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict()
        )
        do_classifier_free_guidance = True
        cross_attention_kwargs = None
        with unet_info as unet:

            # apply scheduler extra args
            extra_step_kwargs = dict()
            if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
                extra_step_kwargs.update(
                    eta=0.0,
                )
            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
                extra_step_kwargs.update(
                    generator=torch.Generator(device=unet.device).manual_seed(0),
                )

            num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)

            # apply denoising_end
            skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
            num_inference_steps = num_inference_steps - skipped_final_steps
            timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
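A worked example of the denoising_start/denoising_end bookkeeping above, as plain arithmetic (no InvokeAI objects): with steps=30, denoising_start=0.4 and denoising_end=0.9, the node executes schedule indices 12 through 26.

def plan_steps(steps: int, denoising_start: float, denoising_end: float, order: int = 1):
    # mirrors the slicing above for a scheduler of the given order
    num_inference_steps = steps
    t_start = int(round(denoising_start * num_inference_steps))
    num_inference_steps -= t_start                                # drop leading steps
    skipped_final_steps = int(round((1 - denoising_end) * steps))
    num_inference_steps -= skipped_final_steps                    # drop trailing steps
    first = t_start * order
    last = first + order * num_inference_steps
    return first, last, num_inference_steps

print(plan_steps(30, 0.4, 0.9))  # -> (12, 27, 15): slice timesteps[12:27], 15 steps executed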

            if not context.services.configuration.sequential_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
                add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
                add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

                prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
                add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
                latents = latents.to(device=unet.device, dtype=unet.dtype)

                with tqdm(total=num_inference_steps) as progress_bar:
                    for i, t in enumerate(timesteps):
                        # expand the latents if we are doing classifier free guidance
                        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                        # predict the noise residual
                        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                        noise_pred = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        # perform guidance
                        if do_classifier_free_guidance:
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                            noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
                            # del noise_pred_uncond
                            # del noise_pred_text

                        # if do_classifier_free_guidance and guidance_rescale > 0.0:
                        #     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                            progress_bar.update()
                            # if callback is not None and i % callback_steps == 0:
                            #     callback(i, t, latents)
            else:
                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
                pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
                latents = latents.to(device=unet.device, dtype=unet.dtype)

                with tqdm(total=num_inference_steps) as progress_bar:
                    for i, t in enumerate(timesteps):
                        # expand the latents if we are doing classifier free guidance
                        # latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                        latent_model_input = scheduler.scale_model_input(latents, t)

                        # import gc
                        # gc.collect()
                        # torch.cuda.empty_cache()

                        # predict the noise residual

added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_time_ids}
|
                        noise_pred_uncond = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=negative_prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
                        noise_pred_text = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        # perform guidance
                        noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

                        # del noise_pred_text
                        # del noise_pred_uncond
                        # import gc
                        # gc.collect()
                        # torch.cuda.empty_cache()

                        # if do_classifier_free_guidance and guidance_rescale > 0.0:
                        #     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                        # del noise_pred
                        # import gc
                        # gc.collect()
                        # torch.cuda.empty_cache()

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                            progress_bar.update()
                            # if callback is not None and i % callback_steps == 0:
                            #     callback(i, t, latents)

            #################

            latents = latents.to("cpu")
            torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)
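Chained through denoising_end on one node and denoising_start on the next, this invocation supports the usual SDXL base-then-refiner split. A hypothetical wiring (the text-to-latents class name and exact field plumbing are assumptions, not shown in this diff):

# hypothetical two-node chain: the base model denoises the first 80% of the
# schedule, the refiner finishes the remaining 20% on the same latents
base = SDXLTextToLatentsInvocation(steps=30, denoising_end=0.8)          # assumed class name
refiner = SDXLLatentsToLatentsInvocation(steps=30, denoising_start=0.8)  # runs steps 24..29
# wire the base node's latents output into refiner.latents; leave refiner.noise
# unset so the add_noise() step above is skipped and denoising continues seamlessly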

@@ -1,6 +1,6 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
-from pathlib import Path, PosixPath
-from typing import Literal, Union, cast
+from pathlib import Path
+from typing import Literal, Union

 import cv2 as cv
 import numpy as np
@@ -11,27 +11,36 @@ from realesrgan import RealESRGANer

 from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin

-from .baseinvocation import BaseInvocation, InvocationContext
+from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
 from .image import ImageOutput

 # TODO: Populate this from disk?
 # TODO: Use model manager to load?
-REALESRGAN_MODELS = Literal[
+ESRGAN_MODELS = Literal[
     "RealESRGAN_x4plus.pth",
     "RealESRGAN_x4plus_anime_6B.pth",
     "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+    "RealESRGAN_x2plus.pth",
 ]


-class RealESRGANInvocation(BaseInvocation):
+class ESRGANInvocation(BaseInvocation):
     """Upscales an image using RealESRGAN."""

-    type: Literal["realesrgan"] = "realesrgan"
+    type: Literal["esrgan"] = "esrgan"
     image: Union[ImageField, None] = Field(default=None, description="The input image")
-    model_name: REALESRGAN_MODELS = Field(
+    model_name: ESRGAN_MODELS = Field(
         default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
     )

+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "Upscale (RealESRGAN)",
+                "tags": ["image", "upscale", "realesrgan"]
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.services.images.get_pil_image(self.image.image_name)
         models_path = context.services.configuration.models_path
@@ -65,19 +74,17 @@ class RealESRGANInvocation(BaseInvocation):
                 scale=4,
             )
             netscale = 4
-        # TODO: add x2 models handling?
-        # elif self.model_name in ["RealESRGAN_x2plus"]:
-        #     # x2 RRDBNet model
-        #     model = RRDBNet(
-        #         num_in_ch=3,
-        #         num_out_ch=3,
-        #         num_feat=64,
-        #         num_block=23,
-        #         num_grow_ch=32,
-        #         scale=2,
-        #     )
-        #     model_path = Path()
-        #     netscale = 2
+        elif self.model_name in ["RealESRGAN_x2plus.pth"]:
+            # x2 RRDBNet model
+            rrdbnet_model = RRDBNet(
+                num_in_ch=3,
+                num_out_ch=3,
+                num_feat=64,
+                num_block=23,
+                num_grow_ch=32,
+                scale=2,
+            )
+            netscale = 2
         else:
             msg = f"Invalid RealESRGAN model: {self.model_name}"
             context.services.logger.error(msg)
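For reference, the new x2 branch mirrors the x4 wiring earlier in this method: construct the matching RRDBNet, then hand it to RealESRGANer together with the same netscale. A hedged standalone sketch (the model path is illustrative):

from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

# the x2 RRDBNet must be constructed with scale=2 or the x2plus weights will not load
rrdbnet_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
upscaler = RealESRGANer(
    scale=2,  # netscale
    model_path="models/core/upscaling/realesrgan/RealESRGAN_x2plus.pth",  # illustrative path
    model=rrdbnet_model,
)
# upscaled_bgr, _ = upscaler.enhance(np_image_bgr)  # enhance() returns (image, img_mode)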

@@ -105,8 +105,6 @@ class EventServiceBase:
    def emit_model_load_started(
        self,
        graph_execution_state_id: str,
-        node: dict,
-        source_node_id: str,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
@@ -117,8 +115,6 @@ class EventServiceBase:
            event_name="model_load_started",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
-                node=node,
-                source_node_id=source_node_id,
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
@@ -129,8 +125,6 @@ class EventServiceBase:
    def emit_model_load_completed(
        self,
        graph_execution_state_id: str,
-        node: dict,
-        source_node_id: str,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
@@ -142,12 +136,12 @@ class EventServiceBase:
            event_name="model_load_completed",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
-                node=node,
-                source_node_id=source_node_id,
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
-                model_info=model_info,
+                hash=model_info.hash,
+                location=model_info.location,
+                precision=str(model_info.precision),
            ),
        )

@@ -339,7 +339,6 @@ class ModelManagerService(ModelManagerServiceBase):
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: Optional[SubModelType] = None,
-        node: Optional[BaseInvocation] = None,
        context: Optional[InvocationContext] = None,
    ) -> ModelInfo:
        """
@@ -347,11 +346,9 @@ class ModelManagerService(ModelManagerServiceBase):
        part (such as the vae) of a diffusers model.
        """

-        # if we are called from within a node, then we get to emit
-        # load start and complete events
-        if node and context:
+        # we can emit model loading events if we are executing with access to the invocation context
+        if context:
            self._emit_load_event(
-                node=node,
                context=context,
                model_name=model_name,
                base_model=base_model,
@@ -366,9 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
            submodel,
        )

-        if node and context:
+        if context:
            self._emit_load_event(
-                node=node,
                context=context,
                model_name=model_name,
                base_model=base_model,
@@ -510,23 +506,19 @@ class ModelManagerService(ModelManagerServiceBase):

    def _emit_load_event(
        self,
-        node,
        context,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
-        submodel: SubModelType,
+        submodel: Optional[SubModelType] = None,
        model_info: Optional[ModelInfo] = None,
    ):
        if context.services.queue.is_canceled(context.graph_execution_state_id):
            raise CanceledException()
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-        source_node_id = graph_execution_state.prepared_source_mapping[node.id]

        if model_info:
            context.services.events.emit_model_load_completed(
                graph_execution_state_id=context.graph_execution_state_id,
-                node=node.dict(),
-                source_node_id=source_node_id,
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
@@ -536,8 +528,6 @@ class ModelManagerService(ModelManagerServiceBase):
        else:
            context.services.events.emit_model_load_started(
                graph_execution_state_id=context.graph_execution_state_id,
-                node=node.dict(),
-                source_node_id=source_node_id,
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,

@@ -466,7 +466,6 @@ class Generator:
            dtype=samples.dtype,
            device=samples.device,
        )

        latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
        latents_ubyte = (
            ((latent_image + 1) / 2)

@@ -69,7 +69,6 @@ transformers.logging.set_verbosity_error()
config = InvokeAIAppConfig.get_config()

Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"

Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path
@@ -223,7 +222,7 @@ def download_conversion_models():

# ---------------------------------------------
def download_realesrgan():
-    logger.info("Installing RealESRGAN models...")
+    logger.info("Installing ESRGAN Upscaling models...")
    URLs = [
        dict(
            url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
@@ -240,6 +239,11 @@ def download_realesrgan():
            dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
            description = "ESRGAN_SRx4_DF2KOST_official.pth",
        ),
+        dict(
+            url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+            dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
+            description = "RealESRGAN_x2plus.pth",
+        ),
    ]
    for model in URLs:
        download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
@@ -629,7 +633,7 @@ def run_console_ui(

    # The third argument is needed in the Windows 11 environment to
    # launch a console window running this program.
-    set_min_terminal_size(MIN_COLS, MIN_LINES, 'invokeai-configure')
+    set_min_terminal_size(MIN_COLS, MIN_LINES)

    # the install-models application spawns a subprocess to install
    # models, and will crash unless this is set before running.
@@ -706,7 +710,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
    old_init_file = root / 'invokeai.init'
    new_init_file = root / 'invokeai.yaml'
    old_hub = root / 'models/hub'
-    migration_needed = old_init_file.exists() and not new_init_file.exists() or old_hub.exists()
+    migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()

    if migration_needed:
        if opt.yes_to_all or \
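The migration_needed change above is a precedence fix as much as a policy change: `and` binds tighter than `or`, so the old expression migrated whenever models/hub existed, even on a fresh install. A quick check:

A, B, C = False, True, True   # old init file absent, new init file present, old hub present
old = A and not B or C        # parsed as (A and not B) or C -> True: migrates on a fresh install
new = (A and not B) and C     # -> False: both conditions must now hold
print(old, new)               # True False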

@@ -10,7 +10,7 @@ from tempfile import TemporaryDirectory
from typing import List, Dict, Callable, Union, Set

import requests
-from diffusers import StableDiffusionPipeline
+from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf
@@ -212,7 +212,7 @@ class ModelInstall(object):
                {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
            ]
        ):
-            models_installed.update(self._install_path(path))
+            models_installed.update({str(model_path_id_or_url): self._install_path(path)})

        # recursive scan
        elif path.is_dir():
@@ -310,6 +310,8 @@ class ModelInstall(object):
        if key := self.reverse_paths.get(path_name):
            (name, base, mtype) = ModelManager.parse_key(key)
            return name
+        elif location.is_dir():
+            return location.name
        else:
            return location.stem

@@ -365,7 +367,7 @@ class ModelInstall(object):
        model = None
        for revision in revisions:
            try:
-                model = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
+                model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
            except:  # most errors are due to fp16 not being present. Fix this to catch other errors
                pass
            if model:

@@ -3,6 +3,7 @@ Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache
+from .lora import ModelPatcher, ONNXModelPatcher
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException
from .model_merge import ModelMerger, MergeInterpolationMethod

@@ -6,11 +6,22 @@ from typing import Optional, Dict, Tuple, Any, Union, List
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle

from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
from onnx import numpy_helper
from onnxruntime import OrtValue
import numpy as np

from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer

# TODO: rename and split this file

class LoRALayerBase:
    #rank: Optional[int]
    #alpha: Optional[float]
@@ -708,3 +719,185 @@ class TextualInversionManager(BaseTextualInversionManager):

        return new_token_ids


class ONNXModelPatcher:

    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

    # based on
    # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: IAIOnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
        prefix: str,
    ):
        from .models.base import IAIOnnxRuntimeModel
        if not isinstance(model, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        orig_weights = dict()

        try:
            blended_loras = dict()

            for lora, lora_weight in loras:
                for layer_key, layer in lora.layers.items():
                    if not layer_key.startswith(prefix):
                        continue

                    layer_key = layer_key.replace(prefix, "")
                    layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
                    if layer_key in blended_loras:
                        blended_loras[layer_key] += layer_weight
                    else:
                        blended_loras[layer_key] = layer_weight

            node_names = dict()
            for node in model.nodes.values():
                node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

            for layer_key, lora_weight in blended_loras.items():
                conv_key = layer_key + "_Conv"
                gemm_key = layer_key + "_Gemm"
                matmul_key = layer_key + "_MatMul"

                if conv_key in node_names or gemm_key in node_names:
                    if conv_key in node_names:
                        conv_node = model.nodes[node_names[conv_key]]
                    else:
                        conv_node = model.nodes[node_names[gemm_key]]

                    weight_name = [n for n in conv_node.input if ".weight" in n][0]
                    orig_weight = model.tensors[weight_name]

                    if orig_weight.shape[-2:] == (1, 1):
                        if lora_weight.shape[-2:] == (1, 1):
                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
                        else:
                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight

                        new_weight = np.expand_dims(new_weight, (2, 3))
                    else:
                        if orig_weight.shape != lora_weight.shape:
                            new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
                        else:
                            new_weight = orig_weight + lora_weight

                    orig_weights[weight_name] = orig_weight
                    model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)

                elif matmul_key in node_names:
                    weight_node = model.nodes[node_names[matmul_key]]
                    matmul_name = [n for n in weight_node.input if "MatMul" in n][0]

                    orig_weight = model.tensors[matmul_name]
                    new_weight = orig_weight + lora_weight.transpose()

                    orig_weights[matmul_name] = orig_weight
                    model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)

                else:
                    # warn? err?
                    pass

            yield

        finally:
            # restore original weights
            for name, orig_weight in orig_weights.items():
                model.tensors[name] = orig_weight

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: IAIOnnxRuntimeModel,
        ti_list: List[Any],
    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
        from .models.base import IAIOnnxRuntimeModel
        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        orig_embeddings = None

        try:
            ti_tokenizer = copy.deepcopy(tokenizer)
            ti_manager = TextualInversionManager(ti_tokenizer)

            def _get_trigger(ti, index):
                trigger = ti.name
                if index > 0:
trigger += f"-!pad-{i}"
|
                return f"<{trigger}>"

            # modify tokenizer
            new_tokens_added = 0
            for ti in ti_list:
                for i in range(ti.embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))

            # modify text_encoder
            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]

            embeddings = np.concatenate(
                (
                    np.copy(orig_embeddings),
                    np.zeros((new_tokens_added, orig_embeddings.shape[1]))
                ),
                axis=0,
            )

            for ti in ti_list:
                ti_tokens = []
                for i in range(ti.embedding.shape[0]):
                    embedding = ti.embedding[i].detach().numpy()
                    trigger = _get_trigger(ti, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if embeddings[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
                        )

                    embeddings[token_id] = embedding
                    ti_tokens.append(token_id)

                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(orig_embeddings.dtype)

            yield ti_tokenizer, ti_manager

        finally:
            # restore
            if orig_embeddings is not None:
                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
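apply_ti above is the standard textual-inversion recipe adapted to an ONNX text encoder: grow the token-embedding table by the number of added trigger tokens, then write each trained vector at its assigned token id. A toy numpy version of just the resize-and-assign step (sizes are illustrative):

import numpy as np

orig = np.zeros((49408, 768), dtype=np.float32)  # CLIP-sized embedding table (illustrative)
new_tokens_added = 2                             # e.g. one 2-vector embedding -> "<tok>" and "<tok-!pad-1>"

table = np.concatenate(
    (orig, np.zeros((new_tokens_added, orig.shape[1]), dtype=orig.dtype)),
    axis=0,
)
trained = np.random.rand(new_tokens_added, 768).astype(orig.dtype)
table[49408:49410] = trained                     # new token ids land at the end of the grown table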

@@ -104,7 +104,8 @@ class ModelCache(object):
        :param sha_chunksize: Chunksize to use when calculating sha256 model hash
        '''
        self.model_infos: Dict[str, ModelBase] = dict()
-        self.lazy_offloading = lazy_offloading
+        # allow lazy offloading only when vram cache enabled
+        self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
        self.precision: torch.dtype = precision
        self.max_cache_size: float = max_cache_size
        self.max_vram_cache_size: float = max_vram_cache_size
@@ -327,6 +328,25 @@ class ModelCache(object):

        refs = sys.getrefcount(cache_entry.model)

+        # manually clear local variable references left over from just-finished function calls;
+        # for some reason Python won't collect them otherwise, even with an immediate gc.collect()
+        if refs > 2:
+            while True:
+                cleared = False
+                for referrer in gc.get_referrers(cache_entry.model):
+                    if type(referrer).__name__ == "frame":
+                        # RuntimeError: cannot clear an executing frame
+                        with suppress(RuntimeError):
+                            referrer.clear()
+                            cleared = True
+                #break
+
+                # repeat if the referrers changed (due to a frame clear), else exit the loop
+                if cleared:
+                    gc.collect()
+                else:
+                    break
+
        device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
        self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")

@@ -362,6 +382,9 @@ class ModelCache(object):
            self.logger.debug(f'GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB')
            vram_in_use += mem.vram_used  # note vram_used is negative
        self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')

+        gc.collect()
+        torch.cuda.empty_cache()

    def _local_model_hash(self, model_path: Union[str, Path]) -> str:
        sha = hashlib.sha256()
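_local_model_hash is cut off here, but the sha_chunksize parameter documented earlier in this file implies the usual streaming pattern: hash multi-gigabyte weight files in fixed-size chunks rather than reading them whole. A sketch under that assumption:

import hashlib
from pathlib import Path

def file_sha256(path: Path, chunksize: int = 16 * 2**20) -> str:
    # stream the file so large checkpoints never have to sit fully in memory
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunksize):
            sha.update(chunk)
    return sha.hexdigest()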

@@ -106,16 +106,16 @@ providing information about a model defined in models.yaml. For example:

>>> models = mgr.list_models()
>>> json.dumps(models[0])
{"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny",
 "model_format": "diffusers",
 "name": "canny",
 "base_model": "sd-1",
 "type": "controlnet"
}

You can filter by model type and base model as shown here:

controlnets = mgr.list_models(model_type=ModelType.ControlNet,
                              base_model=BaseModelType.StableDiffusion1)
for c in controlnets:
@@ -140,14 +140,14 @@ Layout of the `models` directory:

models
├── sd-1
│   ├── controlnet
│   ├── lora
│   ├── main
│   └── embedding
├── sd-2
│   ├── controlnet
│   ├── lora
│   ├── main
│   └── embedding
└── core
    ├── face_reconstruction
@@ -195,7 +195,7 @@ name, base model, type and a dict of model attributes. See
`invokeai/backend/model_management/models` for the attributes required
by each model type.

A model can be deleted using `del_model()`, providing the same
identifying information as `get_model()`

The `heuristic_import()` method will take a set of strings
@@ -304,7 +304,7 @@ class ModelManager(object):
        logger: types.ModuleType = logger,
    ):
        """
        Initialize with the path to the models.yaml config file.
        Optional parameters are the torch device type, precision, max_models,
        and sequential_offload boolean. Note that the default device
        type and precision are set up for a CUDA system running at half precision.
@@ -323,7 +323,7 @@ class ModelManager(object):
        self.config_meta = ConfigMeta(**config.pop("__metadata__"))
        # TODO: metadata not found
        # TODO: version check

        self.app_config = InvokeAIAppConfig.get_config()
        self.logger = logger
        self.cache = ModelCache(
@@ -431,7 +431,7 @@ class ModelManager(object):
        :param model_name: symbolic name of the model in models.yaml
        :param model_type: ModelType enum indicating the type of model to return
        :param base_model: BaseModelType enum indicating the base model used by this model
        :param submodel_type: an optional SubModelType enum indicating the portion of
               the model to retrieve (e.g. SubModelType.Vae)
        """
        model_class = MODEL_CLASSES[base_model][model_type]
@@ -456,7 +456,7 @@ class ModelManager(object):
            raise ModelNotFoundException(f"Model not found - {model_key}")

        # vae/movq override
        # TODO:
        if submodel_type is not None and hasattr(model_config, submodel_type):
            override_path = getattr(model_config, submodel_type)
            if override_path:
@@ -489,7 +489,7 @@ class ModelManager(object):
            self.cache_keys[model_key].add(model_context.key)

        model_hash = "<NO_HASH>"  # TODO:

        return ModelInfo(
            context = model_context,
            name = model_name,
@@ -518,7 +518,7 @@ class ModelManager(object):

    def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
        """
        Return a list of (str, BaseModelType, ModelType) corresponding to all models
        known to the configuration.
        """
        return [(self.parse_key(x)) for x in self.models.keys()]
@@ -568,6 +568,9 @@ class ModelManager(object):
                model_type=cur_model_type,
            )

+            # expose paths as absolute to help web UI
+            if path := model_dict.get('path'):
+                model_dict['path'] = str(self.app_config.root_path / path)
            models.append(model_dict)

        return models
@@ -635,6 +638,10 @@ class ModelManager(object):
        The returned dict has the same format as the dict returned by
        model_info().
        """
+        # relativize paths as they go in - this makes it easier to move the root directory around
+        if path := model_attributes.get('path'):
+            if Path(path).is_relative_to(self.app_config.root_path):
+                model_attributes['path'] = str(Path(path).relative_to(self.app_config.root_path))

        model_class = MODEL_CLASSES[base_model][model_type]
        model_config = model_class.create_config(**model_attributes)
@@ -685,12 +692,12 @@ class ModelManager(object):
        if new_name is None and new_base is None:
            self.logger.error(f"rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
            return

        model_key = self.create_key(model_name, base_model, model_type)
        model_cfg = self.models.get(model_key, None)
        if not model_cfg:
            raise ModelNotFoundException(f"Unknown model: {model_key}")

        old_path = self.app_config.root_path / model_cfg.path
        new_name = new_name or model_name
        new_base = new_base or base_model
@@ -700,7 +707,7 @@ class ModelManager(object):

        # if this is a model file/directory that we manage ourselves, we need to move it
        if old_path.is_relative_to(self.app_config.models_path):
-            new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
+            new_path = self.app_config.root_path / 'models' / BaseModelType(new_base).value / ModelType(model_type).value / new_name
            move(old_path, new_path)
            model_cfg.path = str(new_path.relative_to(self.app_config.root_path))

@@ -719,7 +726,7 @@ class ModelManager(object):
        self.models.pop(model_key, None)  # delete
        self.models[new_key] = model_cfg
        self.commit()

    def convert_model(
        self,
        model_name: str,
@@ -769,12 +776,12 @@ class ModelManager(object):
                # something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
                rmtree(new_diffusers_path)
                raise

        if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
            checkpoint_path.unlink()

        return result

    def search_models(self, search_folder):
        self.logger.info(f"Finding Models In: {search_folder}")
        models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
@@ -817,10 +824,14 @@ class ModelManager(object):
        assert config_file_path is not None, 'no config file path to write to'
        config_file_path = self.app_config.root_path / config_file_path
        tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
-        with open(tmpfile, "w", encoding="utf-8") as outfile:
-            outfile.write(self.preamble())
-            outfile.write(yaml_str)
-        os.replace(tmpfile, config_file_path)
+        try:
+            with open(tmpfile, "w", encoding="utf-8") as outfile:
+                outfile.write(self.preamble())
+                outfile.write(yaml_str)
+            os.replace(tmpfile, config_file_path)
+        except OSError as err:
+            self.logger.warning(f"Could not modify the config file at {config_file_path}")
+            self.logger.warning(err)

    def preamble(self) -> str:
        """
@@ -970,13 +981,12 @@ class ModelManager(object):
        # avoid circular import here
        from invokeai.backend.install.model_install_backend import ModelInstall
        successfully_installed = dict()

        installer = ModelInstall(config = self.app_config,
                                 prediction_type_helper = prediction_type_helper,
                                 model_manager = self)
        for thing in items_to_import:
            installed = installer.heuristic_import(thing)
            successfully_installed.update(installed)
        self.commit()
        return successfully_installed
@@ -12,6 +12,7 @@ from picklescan.scanner import scan_file_path
|
||||
from .models import (
|
||||
BaseModelType, ModelType, ModelVariantType,
|
||||
SchedulerPredictionType, SilenceWarnings,
|
||||
InvalidModelException
|
||||
)
|
||||
from .models.base import read_checkpoint_meta
|
||||
|
||||
@@ -22,7 +23,7 @@ class ModelProbeInfo(object):
|
||||
variant_type: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
format: Literal['diffusers','checkpoint', 'lycoris']
|
||||
format: Literal['diffusers','checkpoint', 'lycoris', 'olive']
|
||||
image_size: int
|
||||
|
||||
class ProbeBase(object):
|
||||
@@ -38,6 +39,8 @@ class ModelProbe(object):
|
||||
|
||||
CLASS2TYPE = {
|
||||
'StableDiffusionPipeline' : ModelType.Main,
|
||||
'StableDiffusionXLPipeline' : ModelType.Main,
|
||||
'StableDiffusionXLImg2ImgPipeline' : ModelType.Main,
|
||||
'AutoencoderKL' : ModelType.Vae,
|
||||
'ControlNetModel' : ModelType.ControlNet,
|
||||
}
|
||||
@@ -59,7 +62,7 @@ class ModelProbe(object):
|
||||
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
|
||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
||||
else:
|
||||
raise ValueError("model parameter {model} is neither a Path, nor a model")
|
||||
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
|
||||
|
||||
@classmethod
|
||||
def probe(cls,
|
||||
@@ -99,9 +102,10 @@ class ModelProbe(object):
|
||||
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
|
||||
and prediction_type==SchedulerPredictionType.VPrediction),
|
||||
format = format,
|
||||
image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
|
||||
and prediction_type==SchedulerPredictionType.VPrediction \
|
||||
) else 512,
|
||||
image_size = 1024 if (base_type in {BaseModelType.StableDiffusionXL,BaseModelType.StableDiffusionXLRefiner}) else \
|
||||
768 if (base_type==BaseModelType.StableDiffusion2 \
|
||||
and prediction_type==SchedulerPredictionType.VPrediction ) else \
|
||||
512
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
@@ -138,7 +142,7 @@ class ModelProbe(object):
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
raise ValueError(f"Unable to determine model type for {model_path}")
|
||||
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
||||
@@ -168,7 +172,7 @@ class ModelProbe(object):
|
||||
return type
|
||||
|
||||
# give up
|
||||
raise ValueError(f"Unable to determine model type for {folder_path}")
|
||||
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
||||
@@ -237,7 +241,7 @@ class CheckpointProbeBase(ProbeBase):
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
else:
|
||||
raise ValueError(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
|
||||
raise InvalidModelException(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
|
||||
|
||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
@@ -248,7 +252,10 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusion1
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
raise ValueError("Cannot determine base type")
|
||||
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
raise InvalidModelException("Cannot determine base type")
|
||||
|
||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||
type = self.get_base_type()
|
||||
@@ -329,7 +336,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif self.checkpoint_path and self.helper:
|
||||
return self.helper(self.checkpoint_path)
|
||||
raise ValueError("Unable to determine base type for {self.checkpoint_path}")
|
||||
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
@@ -360,8 +367,12 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf['cross_attention_dim'] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif unet_conf['cross_attention_dim'] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
elif unet_conf['cross_attention_dim'] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise ValueError(f'Unknown base model for {self.folder_path}')
|
||||
raise InvalidModelException(f'Unknown base model for {self.folder_path}')
|
||||
|
||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||
if self.model:
|
||||
@@ -418,7 +429,7 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
config_file = self.folder_path / 'config.json'
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Cannot determine base type for {self.folder_path}")
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file,'r') as file:
|
||||
config = json.load(file)
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
@@ -435,7 +446,7 @@ class LoRAFolderProbe(FolderProbeBase):
|
||||
model_file = base_file
|
||||
break
|
||||
if not model_file:
|
||||
raise ValueError('Unknown LoRA format encountered')
|
||||
raise InvalidModelException('Unknown LoRA format encountered')
|
||||
return LoRACheckpointProbe(model_file,None).get_base_type()
|
||||
|
||||
############## register probe classes ######
|
||||
|
||||
@@ -4,13 +4,17 @@ from pydantic import BaseModel
from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException, InvalidModelException
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
+from .sdxl import StableDiffusionXLModel
from .vae import VaeModel
from .lora import LoRAModel
from .controlnet import ControlNetModel  # TODO:
from .textual_inversion import TextualInversionModel

+from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model

MODEL_CLASSES = {
    BaseModelType.StableDiffusion1: {
+        ModelType.ONNX: ONNXStableDiffusion1Model,
        ModelType.Main: StableDiffusion1Model,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
@@ -18,12 +22,31 @@ MODEL_CLASSES = {
        ModelType.TextualInversion: TextualInversionModel,
    },
    BaseModelType.StableDiffusion2: {
+        ModelType.ONNX: ONNXStableDiffusion2Model,
        ModelType.Main: StableDiffusion2Model,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
        ModelType.TextualInversion: TextualInversionModel,
    },
+    BaseModelType.StableDiffusionXL: {
+        ModelType.Main: StableDiffusionXLModel,
+        ModelType.Vae: VaeModel,
+        # will not work until support written
+        ModelType.Lora: LoRAModel,
+        ModelType.ControlNet: ControlNetModel,
+        ModelType.TextualInversion: TextualInversionModel,
+        ModelType.ONNX: ONNXStableDiffusion2Model,
+    },
+    BaseModelType.StableDiffusionXLRefiner: {
+        ModelType.Main: StableDiffusionXLModel,
+        ModelType.Vae: VaeModel,
+        # will not work until support written
+        ModelType.Lora: LoRAModel,
+        ModelType.ControlNet: ControlNetModel,
+        ModelType.TextualInversion: TextualInversionModel,
+        ModelType.ONNX: ONNXStableDiffusion2Model,
+    },
    #BaseModelType.Kandinsky2_1: {
    #    ModelType.Main: Kandinsky2_1Model,
    #    ModelType.MoVQ: MoVQModel,
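MODEL_CLASSES is a plain two-level registry keyed by base model, then model type. Resolution elsewhere in this diff is a dictionary lookup, roughly:

# roughly how the manager resolves an implementation (names from this diff)
model_class = MODEL_CLASSES[BaseModelType.StableDiffusionXL][ModelType.Main]
# -> StableDiffusionXLModel; a missing combination raises KeyError, which is
# why each base model lists an entry for every ModelType, even placeholders
model = model_class(model_path, BaseModelType.StableDiffusionXL, ModelType.Main)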

@@ -8,13 +8,19 @@ from abc import ABCMeta, abstractmethod
from pathlib import Path
from picklescan.scanner import scan_file_path
import torch
import numpy as np
import safetensors.torch
-from diffusers import DiffusionPipeline, ConfigMixin
-from pathlib import Path
+from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel

from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union

+import onnx
+from onnx import numpy_helper
+from onnx.external_data_helper import set_external_data
+from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel

class InvalidModelException(Exception):
    pass

@@ -24,9 +30,12 @@ class ModelNotFoundException(Exception):

class BaseModelType(str, Enum):
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
+    StableDiffusionXL = "sdxl"
+    StableDiffusionXLRefiner = "sdxl-refiner"
    #Kandinsky2_1 = "kandinsky-2.1"

class ModelType(str, Enum):
+    ONNX = "onnx"
    Main = "main"
    Vae = "vae"
    Lora = "lora"
@@ -36,8 +45,12 @@ class ModelType(str, Enum):
class SubModelType(str, Enum):
    UNet = "unet"
    TextEncoder = "text_encoder"
+    TextEncoder2 = "text_encoder_2"
    Tokenizer = "tokenizer"
+    Tokenizer2 = "tokenizer_2"
    Vae = "vae"
+    VaeDecoder = "vae_decoder"
+    VaeEncoder = "vae_encoder"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"
    #MoVQ = "movq"
@@ -250,16 +263,18 @@ class DiffusersModel(ModelBase):
            try:
                # TODO: set cache_dir to /dev/null to be sure that cache not used?
                model = self.child_types[child_type].from_pretrained(
-                    self.model_path,
-                    subfolder=child_type.value,
+                    os.path.join(self.model_path, child_type.value),
+                    #subfolder=child_type.value,
                    torch_dtype=torch_dtype,
                    variant=variant,
                    local_files_only=True,
                )
                break
            except Exception as e:
-                #print("====ERR LOAD====")
-                #print(f"{variant}: {e}")
+                print("====ERR LOAD====")
+                print(f"{variant}: {e}")
+                import traceback
+                traceback.print_exc()
                pass
        else:
            raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
@@ -426,3 +441,188 @@ class SilenceWarnings(object):
        transformers_logging.set_verbosity(self.transformers_verbosity)
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        warnings.simplefilter('default')

ONNX_WEIGHTS_NAME = "model.onnx"

class IAIOnnxRuntimeModel:
    class _tensor_access:

        def __init__(self, model):
            self.model = model
            self.indexes = dict()
            for idx, obj in enumerate(self.model.proto.graph.initializer):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):
            return self.model.data[key].numpy()

        def __setitem__(self, key: str, value: np.ndarray):
            new_node = numpy_helper.from_array(value)
            # set_external_data(new_node, location="in-memory-location")
            new_node.name = key
            # new_node.ClearField("raw_data")
            del self.model.proto.graph.initializer[self.indexes[key]]
            self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
            self.model.data[key] = OrtValue.ortvalue_from_numpy(value)

        # __delitem__

        def __contains__(self, key: str):
            return key in self.model.data

        def items(self):
            raise NotImplementedError("tensor.items")
            #return [(obj.name, obj) for obj in self.raw_proto]

        def keys(self):
            return self.model.data.keys()

        def values(self):
            raise NotImplementedError("tensor.values")
            #return [obj for obj in self.raw_proto]


    class _access_helper:
        def __init__(self, raw_proto):
            self.indexes = dict()
            self.raw_proto = raw_proto
            for idx, obj in enumerate(raw_proto):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):
            return self.raw_proto[self.indexes[key]]

        def __setitem__(self, key: str, value):
            index = self.indexes[key]
            del self.raw_proto[index]
            self.raw_proto.insert(index, value)

        # __delitem__

        def __contains__(self, key: str):
            return key in self.indexes

        def items(self):
            return [(obj.name, obj) for obj in self.raw_proto]

        def keys(self):
            return self.indexes.keys()

        def values(self):
            return [obj for obj in self.raw_proto]

    def __init__(self, model_path: str, provider: Optional[str]):
        self.path = model_path
        self.session = None
        self.provider = provider or "CPUExecutionProvider"
        """
        self.data_path = self.path + "_data"
        if not os.path.exists(self.data_path):
            print(f"Moving model tensors to separate file: {self.data_path}")
            tmp_proto = onnx.load(model_path, load_external_data=True)
            onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
            del tmp_proto
            gc.collect()

        self.proto = onnx.load(model_path, load_external_data=False)
        """

        self.proto = onnx.load(model_path, load_external_data=True)
        self.data = dict()
        for tensor in self.proto.graph.initializer:
            name = tensor.name

            if tensor.HasField("raw_data"):
                npt = numpy_helper.to_array(tensor)
                orv = OrtValue.ortvalue_from_numpy(npt)
                self.data[name] = orv
                # set_external_data(tensor, location="in-memory-location")
                tensor.name = name
                # tensor.ClearField("raw_data")

        self.nodes = self._access_helper(self.proto.graph.node)
        self.initializers = self._access_helper(self.proto.graph.initializer)
        # print(self.proto.graph.input)
        # print(self.proto.graph.initializer)

        self.tensors = self._tensor_access(self)

    # TODO: integrate with model manager/cache
    def create_session(self):
        if self.session is None:
            #onnx.save(self.proto, "tmp.onnx")
            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
            # TODO: something to be able to get weight when they already moved outside of model proto
            #(trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
            sess = SessionOptions()
            #self._external_data.update(**external_data)
            # sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
            # sess.enable_profiling = True

            # sess.intra_op_num_threads = 1
            # sess.inter_op_num_threads = 1
            # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
            # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
            # sess.enable_cpu_mem_arena = True
            # sess.enable_mem_pattern = True
            # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1")  ########### It's the key code

            sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
            sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
            sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
            sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
            sess.add_free_dimension_override_by_name("unet_sample_height", 64)
            sess.add_free_dimension_override_by_name("unet_sample_width", 64)
            sess.add_free_dimension_override_by_name("unet_time_batch", 1)
            self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess)
            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
            self.io_binding = self.session.io_binding()

    def release_session(self):
        self.session = None
        import gc
        gc.collect()

    def __call__(self, **kwargs):
        if self.session is None:
            raise Exception("You should call create_session before running model")

        inputs = {k: np.array(v) for k, v in kwargs.items()}
        output_names = self.session.get_outputs()
        for k in inputs:
            self.io_binding.bind_cpu_input(k, inputs[k])
        for name in output_names:
            self.io_binding.bind_output(name.name)
        self.session.run_with_iobinding(self.io_binding, None)
        return self.io_binding.copy_outputs_to_cpu()

    # compatibility with the diffusers load code
    @classmethod
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        subfolder: Union[str, Path] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["SessionOptions"] = None,
        **kwargs,
    ):
        file_name = file_name or ONNX_WEIGHTS_NAME

        if os.path.isdir(model_id):
            model_path = model_id
            if subfolder is not None:
                model_path = os.path.join(model_path, subfolder)
            model_path = os.path.join(model_path, file_name)

        else:
            model_path = model_id

        # load model from local directory
        if not os.path.isfile(model_path):
            raise Exception(f"Model not found: {model_path}")

        # TODO: session options
        return cls(model_path, provider=provider)
||||
|
||||
|
||||
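A brief aside: the wrapper above can be exercised directly. The following is a minimal sketch, not part of the diff; the model directory path is hypothetical, and the input names and shapes are assumptions taken from the free-dimension overrides in create_session().

    # Sketch only: drive IAIOnnxRuntimeModel end to end on CPU.
    import numpy as np

    unet = IAIOnnxRuntimeModel.from_pretrained(
        "/path/to/sd-onnx-model",        # hypothetical converted-model directory
        subfolder="unet",
        provider="CPUExecutionProvider",
    )
    unet.create_session()                # builds the InferenceSession and IO binding
    out = unet(
        sample=np.zeros((2, 4, 64, 64), dtype=np.float32),
        timestep=np.array([1], dtype=np.int64),
        encoder_hidden_states=np.zeros((2, 77, 768), dtype=np.float32),
    )
    noise_pred = out[0]                  # __call__ copies outputs back to CPU
    unet.release_session()               # drop the session so memory can be reclaimed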
invokeai/backend/model_management/models/sdxl.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import os
import json
from enum import Enum
from pydantic import Field
from typing import Literal, Optional
from .base import (
    ModelConfigBase,
    BaseModelType,
    ModelType,
    ModelVariantType,
    DiffusersModel,
    read_checkpoint_meta,
    classproperty,
)
from omegaconf import OmegaConf


class StableDiffusionXLModelFormat(str, Enum):
    Checkpoint = "checkpoint"
    Diffusers = "diffusers"


class StableDiffusionXLModel(DiffusersModel):

    # TODO: check that configs are overwritten properly
    class DiffusersConfig(ModelConfigBase):
        model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        config: str
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
        assert model_type == ModelType.Main
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusionXL,
            model_type=ModelType.Main,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)
        if model_format == StableDiffusionXLModelFormat.Checkpoint:
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
            else:
                checkpoint = read_checkpoint_meta(path)
                checkpoint = checkpoint.get('state_dict', checkpoint)
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == StableDiffusionXLModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if os.path.exists(unet_config_path):
                with open(unet_config_path, "r") as f:
                    unet_config = json.loads(f.read())
                in_channels = unet_config['in_channels']
            else:
                raise Exception("Unsupported stable diffusion diffusers format (possibly ONNX?)")

        else:
            raise NotImplementedError(f"Unknown stable diffusion XL format: {model_format}")

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unknown stable diffusion XL model format")

        if ckpt_config_path is None:
            # TODO: implement picking
            pass

        return cls.create_config(
            path=path,
            model_format=model_format,
            config=ckpt_config_path,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        if os.path.isdir(model_path):
            return StableDiffusionXLModelFormat.Diffusers
        else:
            return StableDiffusionXLModelFormat.Checkpoint

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        if isinstance(config, cls.CheckpointConfig):
            raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported')
        else:
            return model_path
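For reference, the probe above keys the variant off the UNet's in_channels: 4 is a normal model, 5 is depth, 9 is inpainting. A hedged usage sketch with hypothetical paths:

    # Sketch only: probing hypothetical SDXL models with the class defined above.
    ckpt_config = StableDiffusionXLModel.probe_config(
        "/models/sd_xl_base_0.9.safetensors",
        config="/configs/xl-inference-v.yaml",  # optional legacy yaml; omit to read the checkpoint itself
    )
    diffusers_config = StableDiffusionXLModel.probe_config("/models/sdxl-base-diffusers")
    # a 4-channel UNet yields variant == ModelVariantType.Normal in both cases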
@@ -5,14 +5,11 @@ from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
    ModelBase,
    ModelConfigBase,
    BaseModelType,
    ModelType,
    SubModelType,
    ModelVariantType,
    DiffusersModel,
    SchedulerPredictionType,
    SilenceWarnings,
    read_checkpoint_meta,
    classproperty,
@@ -248,6 +245,12 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
        ModelVariantType.Normal: "v2-inference-v.yaml",  # best guess, as we can't differentiate with base(512)
        ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
        ModelVariantType.Depth: "v2-midas-inference.yaml",
    },
+   # note that these .yaml files don't yet exist!
+   BaseModelType.StableDiffusionXL: {
+       ModelVariantType.Normal: "xl-inference-v.yaml",
+       ModelVariantType.Inpaint: "xl-inpainting-inference.yaml",
+       ModelVariantType.Depth: "xl-midas-inference.yaml",
+   }
}

@@ -263,6 +266,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):


# TODO: rework
+# Note that convert_ckpt_to_diffusers does not currently support conversion of SDXL models
def _convert_ckpt_and_cache(
    version: BaseModelType,
    model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
@@ -0,0 +1,156 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
    ModelBase,
    ModelConfigBase,
    BaseModelType,
    ModelType,
    SubModelType,
    ModelVariantType,
    DiffusersModel,
    SchedulerPredictionType,
    SilenceWarnings,
    read_checkpoint_meta,
    classproperty,
    OnnxRuntimeModel,
    IAIOnnxRuntimeModel,
)
from invokeai.app.services.config import InvokeAIAppConfig


class ONNXStableDiffusion1Model(DiffusersModel):

    class Config(ModelConfigBase):
        model_format: None
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.ONNX,
        )

        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel

        # TODO: check that no optimum models provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        in_channels = 4  # TODO: probe the real value from the ONNX unet

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unknown stable diffusion 1.* model format")

        return cls.create_config(
            path=path,
            model_format=model_format,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        return None

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        return model_path


class ONNXStableDiffusion2Model(DiffusersModel):

    # TODO: check that configs are overwritten properly
    class Config(ModelConfigBase):
        model_format: None
        variant: ModelVariantType
        prediction_type: SchedulerPredictionType
        upcast_attention: bool

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.ONNX,
        )

        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel
        # TODO: check that no optimum models provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        in_channels = 4  # TODO: probe the real value from the ONNX unet

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unknown stable diffusion 2.* model format")

        if variant == ModelVariantType.Normal:
            prediction_type = SchedulerPredictionType.VPrediction
            upcast_attention = True
        else:
            prediction_type = SchedulerPredictionType.Epsilon
            upcast_attention = False

        return cls.create_config(
            path=path,
            model_format=model_format,
            variant=variant,
            prediction_type=prediction_type,
            upcast_attention=upcast_attention,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        return None

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        return model_path
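Worth spelling out what the constructor loops above accomplish. A sketch, assuming child_types is the name-to-loader-class mapping that DiffusersModel builds from the pipeline's model_index.json (the exact keys here are illustrative):

    # Sketch only: the substitution performed in __init__ above. Values are
    # replaced in place, so iteration over items() is safe.
    child_types = {
        "unet": OnnxRuntimeModel,           # illustrative entries
        "text_encoder": OnnxRuntimeModel,
        "vae_decoder": OnnxRuntimeModel,
        "scheduler": object,                # non-ONNX children are untouched
    }
    for child_name, child_type in child_types.items():
        if child_type is OnnxRuntimeModel:
            child_types[child_name] = IAIOnnxRuntimeModel
    # every ONNX submodel now loads through the IO-binding wrapper defined earlier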
@@ -16,6 +16,7 @@ from .base import (
    calc_model_size_by_data,
    classproperty,
    InvalidModelException,
    ModelNotFoundException,
)
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
@@ -422,7 +422,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    noise: torch.Tensor,
    callback: Callable[[PipelineIntermediateState], None] = None,
-   run_id=None,
    **kwargs,
) -> InvokeAIStableDiffusionPipelineOutput:
    r"""
    Function invoked when calling the pipeline for generation.
@@ -443,7 +442,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    noise=noise,
-   run_id=run_id,
    callback=callback,
    **kwargs,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@@ -469,7 +467,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
-   run_id=None,
    callback: Callable[[PipelineIntermediateState], None] = None,
    control_data: List[ControlNetData] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
    if self.scheduler.config.get("cpu_only", False):
        scheduler_device = torch.device('cpu')
@@ -487,11 +484,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    timesteps,
    conditioning_data,
    noise=noise,
-   additional_guidance=additional_guidance,
-   run_id=run_id,
-   callback=callback,
+   additional_guidance=additional_guidance,
+   control_data=control_data,
    **kwargs,
+   callback=callback,
)
return result.latents, result.attention_map_saver
@@ -505,42 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    run_id: str = None,
    additional_guidance: List[Callable] = None,
    control_data: List[ControlNetData] = None,
    **kwargs,
):
-   def _pad_conditioning(cond, target_len, encoder_attention_mask):
-       conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
-
-       if cond.shape[1] < max_len:
-           conditioning_attention_mask = torch.cat([
-               conditioning_attention_mask,
-               torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
-           ], dim=1)
-
-           cond = torch.cat([
-               cond,
-               torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
-           ], dim=1)
-
-       if encoder_attention_mask is None:
-           encoder_attention_mask = conditioning_attention_mask
-       else:
-           encoder_attention_mask = torch.cat([
-               encoder_attention_mask,
-               conditioning_attention_mask,
-           ])
-
-       return cond, encoder_attention_mask
-
-   encoder_attention_mask = None
-   if conditioning_data.unconditioned_embeddings.shape[1] != conditioning_data.text_embeddings.shape[1]:
-       max_len = max(conditioning_data.unconditioned_embeddings.shape[1], conditioning_data.text_embeddings.shape[1])
-       conditioning_data.unconditioned_embeddings, encoder_attention_mask = _pad_conditioning(
-           conditioning_data.unconditioned_embeddings, max_len, encoder_attention_mask
-       )
-       conditioning_data.text_embeddings, encoder_attention_mask = _pad_conditioning(
-           conditioning_data.text_embeddings, max_len, encoder_attention_mask
-       )
-
    self._adjust_memory_efficient_attention(latents)
    if run_id is None:
        run_id = secrets.token_urlsafe(self.ID_LENGTH)
@@ -580,8 +542,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    total_step_count=len(timesteps),
    additional_guidance=additional_guidance,
    control_data=control_data,
-   encoder_attention_mask=encoder_attention_mask,
    **kwargs,
)
latents = step_output.prev_sample

@@ -623,7 +583,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    total_step_count: int,
    additional_guidance: List[Callable] = None,
    control_data: List[ControlNetData] = None,
    **kwargs,
):
    # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
    timestep = t[0]
@@ -638,8 +597,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    down_block_res_samples, mid_block_res_sample = None, None

    if control_data is not None:
        # TODO: rewrite to pass with conditionings
-       encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
        # control_data should be type List[ControlNetData]
        # this loop covers both ControlNet (one ControlNetData in list)
        # and MultiControlNet (multiple ControlNetData in list)
@@ -669,9 +626,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

    if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
        encoder_hidden_states = conditioning_data.text_embeddings
+       encoder_attention_mask = None
    else:
-       encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
-                                          conditioning_data.text_embeddings])
+       encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
+           conditioning_data.unconditioned_embeddings,
+           conditioning_data.text_embeddings,
+       )
    if isinstance(control_datum.weight, list):
        # if controlnet has multiple weights, use the weight for the current step
        controlnet_weight = control_datum.weight[step_index]
@@ -237,6 +237,39 @@ class InvokeAIDiffuserComponent:
    )
    return latents

+   def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+       # NOTE: target_len is effectively unused; the closure reads max_len from the caller below
+       def _pad_conditioning(cond, target_len, encoder_attention_mask):
+           conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+           if cond.shape[1] < max_len:
+               conditioning_attention_mask = torch.cat([
+                   conditioning_attention_mask,
+                   torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+               ], dim=1)
+
+               cond = torch.cat([
+                   cond,
+                   torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
+               ], dim=1)
+
+           if encoder_attention_mask is None:
+               encoder_attention_mask = conditioning_attention_mask
+           else:
+               encoder_attention_mask = torch.cat([
+                   encoder_attention_mask,
+                   conditioning_attention_mask,
+               ])
+
+           return cond, encoder_attention_mask
+
+       encoder_attention_mask = None
+       if unconditioning.shape[1] != conditioning.shape[1]:
+           max_len = max(unconditioning.shape[1], conditioning.shape[1])
+           unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
+           conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+
+       return torch.cat([unconditioning, conditioning]), encoder_attention_mask

    # methods below are called from do_diffusion_step and should be considered private to this class.

    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
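To make the padding behavior concrete, here is a small sanity sketch with dummy tensors (shapes chosen arbitrarily); it mirrors what the helper above computes without needing an InvokeAIDiffuserComponent instance:

    # Sketch only: what _concat_conditionings_for_batch does for mismatched lengths.
    import torch

    uncond = torch.randn(1, 77, 768)    # plain prompt, one 77-token chunk
    cond = torch.randn(1, 154, 768)     # long prompt, two chunks

    max_len = max(uncond.shape[1], cond.shape[1])
    uncond_padded = torch.cat([uncond, torch.zeros(1, max_len - uncond.shape[1], 768)], dim=1)
    mask_uncond = torch.cat([torch.ones(1, 77), torch.zeros(1, max_len - 77)], dim=1)
    mask_cond = torch.ones(1, max_len)

    batch = torch.cat([uncond_padded, cond])      # (2, 154, 768), ready for one UNet pass
    mask = torch.cat([mask_uncond, mask_cond])    # (2, 154): 1 = real token, 0 = padding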
@@ -244,9 +277,13 @@ class InvokeAIDiffuserComponent:
    x_twice = torch.cat([x] * 2)
    sigma_twice = torch.cat([sigma] * 2)

-   both_conditionings = torch.cat([unconditioning, conditioning])
+   both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
+       unconditioning, conditioning
+   )
    both_results = self.model_forward_callback(
-       x_twice, sigma_twice, both_conditionings, **kwargs,
+       x_twice, sigma_twice, both_conditionings,
+       encoder_attention_mask=encoder_attention_mask,
+       **kwargs,
    )
    unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
    return unconditioned_next_x, conditioned_next_x
@@ -260,8 +297,32 @@ class InvokeAIDiffuserComponent:
    **kwargs,
):
    # low-memory sequential path
-   unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
-   conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
+   uncond_down_block, cond_down_block = None, None
+   down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+   if down_block_additional_residuals is not None:
+       uncond_down_block, cond_down_block = [], []
+       for down_block in down_block_additional_residuals:
+           _uncond_down, _cond_down = down_block.chunk(2)
+           uncond_down_block.append(_uncond_down)
+           cond_down_block.append(_cond_down)
+
+   uncond_mid_block, cond_mid_block = None, None
+   mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+   if mid_block_additional_residual is not None:
+       uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
+   unconditioned_next_x = self.model_forward_callback(
+       x, sigma, unconditioning,
+       down_block_additional_residuals=uncond_down_block,
+       mid_block_additional_residual=uncond_mid_block,
+       **kwargs,
+   )
+   conditioned_next_x = self.model_forward_callback(
+       x, sigma, conditioning,
+       down_block_additional_residuals=cond_down_block,
+       mid_block_additional_residual=cond_mid_block,
+       **kwargs,
+   )
    return unconditioned_next_x, conditioned_next_x

# TODO: looks unused
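The reason for the chunking above: ControlNet residuals arrive batched for the (unconditioned, conditioned) pair, but the sequential path runs the UNet once per half. A dummy-tensor sketch of the split (channel sizes are illustrative SD-1.5 values):

    # Sketch only: splitting batched ControlNet residuals into per-pass halves.
    import torch

    down_block = torch.randn(2, 320, 64, 64)        # batched (uncond, cond) residual
    _uncond_down, _cond_down = down_block.chunk(2)  # each (1, 320, 64, 64)

    mid_block = torch.randn(2, 1280, 8, 8)
    uncond_mid, cond_mid = mid_block.chunk(2)       # each (1, 1280, 8, 8)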
@@ -295,6 +356,20 @@ class InvokeAIDiffuserComponent:
):
    context: Context = self.cross_attention_control_context

+   uncond_down_block, cond_down_block = None, None
+   down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+   if down_block_additional_residuals is not None:
+       uncond_down_block, cond_down_block = [], []
+       for down_block in down_block_additional_residuals:
+           _uncond_down, _cond_down = down_block.chunk(2)
+           uncond_down_block.append(_uncond_down)
+           cond_down_block.append(_cond_down)
+
+   uncond_mid_block, cond_mid_block = None, None
+   mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+   if mid_block_additional_residual is not None:
+       uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
    cross_attn_processor_context = SwapCrossAttnContext(
        modified_text_embeddings=context.arguments.edited_conditioning,
        index_map=context.cross_attention_index_map,
@@ -307,6 +382,8 @@ class InvokeAIDiffuserComponent:
        sigma,
        unconditioning,
        {"swap_cross_attn_context": cross_attn_processor_context},
+       down_block_additional_residuals=uncond_down_block,
+       mid_block_additional_residual=uncond_mid_block,
        **kwargs,
    )

@@ -319,6 +396,8 @@ class InvokeAIDiffuserComponent:
        sigma,
        conditioning,
        {"swap_cross_attn_context": cross_attn_processor_context},
+       down_block_additional_residuals=cond_down_block,
+       mid_block_additional_residual=cond_mid_block,
        **kwargs,
    )
    return unconditioned_next_x, conditioned_next_x
@@ -16,6 +16,14 @@ sd-2/main/stable-diffusion-2-inpainting:
  description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
  repo_id: stabilityai/stable-diffusion-2-inpainting
  recommended: False
+sdxl/main/stable-diffusion-xl-base-0-9:
+  description: Stable Diffusion XL base model (12 GB; access token required)
+  repo_id: stabilityai/stable-diffusion-xl-base-0.9
+  recommended: False
+sdxl-refiner/main/stable-diffusion-xl-refiner-0-9:
+  description: Stable Diffusion XL refiner model (12 GB; access token required)
+  repo_id: stabilityai/stable-diffusion-xl-refiner-0.9
+  recommended: False
sd-1/main/Analog-Diffusion:
  description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
  repo_id: wavymulder/Analog-Diffusion
@@ -96,5 +104,6 @@ sd-1/embedding/ahx-beta-453407d:
  repo_id: sd-concepts-library/ahx-beta-453407d
sd-1/lora/LowRA:
  path: https://civitai.com/api/download/models/63006
+  recommended: True
sd-1/lora/Ink scenery:
  path: https://civitai.com/api/download/models/83390
@@ -701,7 +701,7 @@ def select_and_download_models(opt: Namespace):

    # the third argument is needed in the Windows 11 environment in
    # order to launch and resize a console window running this program
-   set_min_terminal_size(MIN_COLS, MIN_LINES, 'invokeai-model-install')
+   set_min_terminal_size(MIN_COLS, MIN_LINES)
    installApp = AddModelApplication(opt)
    try:
        installApp.run()
@@ -17,28 +17,20 @@ from shutil import get_terminal_size
from curses import BUTTON2_CLICKED, BUTTON3_CLICKED

# minimum size for UIs
-MIN_COLS = 130
+MIN_COLS = 136
MIN_LINES = 45

# -------------------------------------
-def set_terminal_size(columns: int, lines: int, launch_command: str = None):
+def set_terminal_size(columns: int, lines: int):
    ts = get_terminal_size()
    width = max(columns, ts.columns)
    height = max(lines, ts.lines)

    OS = platform.uname().system
    if OS == "Windows":
-       # The new Windows Terminal doesn't resize, so we relaunch in a CMD window.
-       # Would prefer to use execvpe() here, but somehow it is not working properly
-       # in the Windows 10 environment.
-       if 'IA_RELAUNCHED' not in os.environ:
-           args = ['conhost']
-           args.extend([launch_command] if launch_command else [sys.argv[0]])
-           args.extend(sys.argv[1:])
-           os.environ['IA_RELAUNCHED'] = 'True'
-           os.execvp('conhost', args)
-       else:
-           _set_terminal_size_powershell(width, height)
+       pass
+       # not working reliably - ask user to adjust the window
+       # _set_terminal_size_powershell(width, height)
    elif OS in ["Darwin", "Linux"]:
        _set_terminal_size_unix(width, height)

@@ -84,20 +76,14 @@ def _set_terminal_size_unix(width: int, height: int):
    sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
    sys.stdout.flush()

-def set_min_terminal_size(min_cols: int, min_lines: int, launch_command: str = None):
+def set_min_terminal_size(min_cols: int, min_lines: int):
    # make sure there's enough room for the ui
    term_cols, term_lines = get_terminal_size()
    if term_cols >= min_cols and term_lines >= min_lines:
        return
    cols = max(term_cols, min_cols)
    lines = max(term_lines, min_lines)
-   set_terminal_size(cols, lines, launch_command)
-
-   # did it work?
-   term_cols, term_lines = get_terminal_size()
-   if term_cols < cols or term_lines < lines:
-       print(f'This window is too small for optimal display. For best results please enlarge it.')
-       input('After resizing, press any key to continue...')
+   set_terminal_size(cols, lines)

class IntSlider(npyscreen.Slider):
    def translate_value(self):
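For context, the Unix resize path relies on the XTerm window-ops escape sequence; whether it works depends entirely on the terminal emulator, which is why the Windows branch above now simply asks the user to resize. A standalone sketch:

    # Sketch only: the escape sequence used by _set_terminal_size_unix above.
    # Many emulators ignore it, so callers should re-check the size afterwards.
    import sys
    from shutil import get_terminal_size

    def request_resize(width: int, height: int) -> None:
        sys.stdout.write(f"\x1b[8;{height};{width}t")  # XTerm: resize to height x width cells
        sys.stdout.flush()

    request_resize(136, 45)        # the new MIN_COLS x MIN_LINES
    print(get_terminal_size())     # verify whether the emulator honored the request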
invokeai/frontend/web/dist/assets/App-3986879c.js (vendored, 169 lines; diff suppressed because one or more lines are too long)
invokeai/frontend/web/dist/assets/App-879ff07f.js (vendored, new file, 169 lines; diff suppressed)
invokeai/frontend/web/dist/assets/App-c8b96e06.js (vendored, 169 lines; diff suppressed)
invokeai/frontend/web/dist/assets/MantineProvider-81517a17.js (vendored, new file, 1 line; diff suppressed)
(three further vendored asset diffs suppressed)
invokeai/frontend/web/dist/assets/ThemeLocaleProvider-49b5c7c4.js (vendored, new file, 302 lines; diff suppressed)
(one further vendored asset diff suppressed)
invokeai/frontend/web/dist/assets/index-8888b06f.js (vendored, 125 lines; diff suppressed)
invokeai/frontend/web/dist/assets/index-ba194473.js (vendored, new file, 125 lines; diff suppressed)
invokeai/frontend/web/dist/assets/index-f1a5f9cf.js (vendored, 125 lines; diff suppressed)
invokeai/frontend/web/dist/index.html (vendored, 2 lines changed)
@@ -12,7 +12,7 @@
      margin: 0;
    }
  </style>
- <script type="module" crossorigin src="./assets/index-8888b06f.js"></script>
+ <script type="module" crossorigin src="./assets/index-ba194473.js"></script>
</head>

<body dir="ltr">
invokeai/frontend/web/dist/locales/en.json (vendored, 21 lines changed)
@@ -399,6 +399,8 @@
  "deleteModel": "Delete Model",
  "deleteConfig": "Delete Config",
  "deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
+ "modelDeleted": "Model Deleted",
+ "modelDeleteFailed": "Failed to delete model",
  "deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
  "formMessageDiffusersModelLocation": "Diffusers Model Location",
  "formMessageDiffusersModelLocationDesc": "Please enter at least one.",
@@ -408,11 +410,13 @@
  "convertToDiffusers": "Convert To Diffusers",
  "convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.",
  "convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
- "convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
+ "convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in the InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
  "convertToDiffusersHelpText4": "This is a one-time process only. It might take around 30s-60s depending on the specifications of your computer.",
  "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
  "convertToDiffusersHelpText6": "Do you wish to convert this model?",
  "convertToDiffusersSaveLocation": "Save Location",
+ "noCustomLocationProvided": "No Custom Location Provided",
+ "convertingModelBegin": "Converting Model. Please wait.",
  "v1": "v1",
  "v2_base": "v2 (512px)",
  "v2_768": "v2 (768px)",
@@ -450,7 +454,8 @@
  "none": "none",
  "addDifference": "Add Difference",
  "pickModelType": "Pick Model Type",
- "selectModel": "Select Model"
+ "selectModel": "Select Model",
+ "importModels": "Import Models"
},
"parameters": {
  "general": "General",
@@ -572,6 +577,7 @@
  "uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
  "downloadImageStarted": "Image Download Started",
  "imageCopied": "Image Copied",
+ "problemCopyingImage": "Unable to Copy Image",
  "imageLinkCopied": "Image Link Copied",
  "problemCopyingImageLink": "Unable to Copy Image Link",
  "imageNotLoaded": "No Image Loaded",
@@ -688,6 +694,15 @@
  "reloadSchema": "Reload Schema",
  "saveNodes": "Save Nodes",
  "loadNodes": "Load Nodes",
- "clearNodes": "Clear Nodes"
+ "clearNodes": "Clear Nodes",
+ "zoomInNodes": "Zoom In",
+ "zoomOutNodes": "Zoom Out",
+ "fitViewportNodes": "Fit View",
+ "hideGraphNodes": "Hide Graph Overlay",
+ "showGraphNodes": "Show Graph Overlay",
+ "hideLegendNodes": "Hide Field Type Legend",
+ "showLegendNodes": "Show Field Type Legend",
+ "hideMinimapnodes": "Hide MiniMap",
+ "showMinimapnodes": "Show MiniMap"
}
}
@@ -1,7 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
+import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
  setActiveTab,
  toggleGalleryPanel,
@@ -14,10 +16,11 @@ import React, { memo } from 'react';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';

const globalHotkeysSelector = createSelector(
- (state: RootState) => state.hotkeys,
- (hotkeys) => {
+ [(state: RootState) => state.hotkeys, (state: RootState) => state.ui],
+ (hotkeys, ui) => {
    const { shift } = hotkeys;
-   return { shift };
+   const { shouldPinParametersPanel, shouldPinGallery } = ui;
+   return { shift, shouldPinGallery, shouldPinParametersPanel };
  },
  {
    memoizeOptions: {
@@ -34,7 +37,10 @@ const globalHotkeysSelector = createSelector(
 */
const GlobalHotkeys: React.FC = () => {
  const dispatch = useAppDispatch();
- const { shift } = useAppSelector(globalHotkeysSelector);
+ const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector(
+   globalHotkeysSelector
+ );
+ const activeTabName = useAppSelector(activeTabNameSelector);

  useHotkeys(
    '*',
@@ -51,18 +57,30 @@ const GlobalHotkeys: React.FC = () => {

useHotkeys('o', () => {
  dispatch(toggleParametersPanel());
+ if (activeTabName === 'unifiedCanvas' && shouldPinParametersPanel) {
+   dispatch(requestCanvasRescale());
+ }
});

useHotkeys(['shift+o'], () => {
  dispatch(togglePinParametersPanel());
+ if (activeTabName === 'unifiedCanvas') {
+   dispatch(requestCanvasRescale());
+ }
});

useHotkeys('g', () => {
  dispatch(toggleGalleryPanel());
+ if (activeTabName === 'unifiedCanvas' && shouldPinGallery) {
+   dispatch(requestCanvasRescale());
+ }
});

useHotkeys(['shift+g'], () => {
  dispatch(togglePinGalleryPanel());
+ if (activeTabName === 'unifiedCanvas') {
+   dispatch(requestCanvasRescale());
+ }
});

useHotkeys('1', () => {
@@ -59,15 +59,8 @@ export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {

export type Scheduler = (typeof SCHEDULER_NAMES)[number];

-// Valid upscaling levels
-export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
-  { label: '2x', value: '2' },
-  { label: '4x', value: '4' },
-];
-
export const NUMPY_RAND_MIN = 0;

export const NUMPY_RAND_MAX = 2147483647;

export const FACETOOL_TYPES = ['gfpgan', 'codeformer'] as const;

export const NODE_MIN_WIDTH = 250;
@@ -88,6 +88,9 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
+import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
+import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
+import { addUpscaleRequestedListener } from './listeners/upscaleRequested';

export const listenerMiddleware = createListenerMiddleware();

@@ -177,6 +180,8 @@ addSocketConnectedListener();
addSocketDisconnectedListener();
addSocketSubscribedListener();
addSocketUnsubscribedListener();
+addModelLoadStartedEventListener();
+addModelLoadCompletedEventListener();

// Session Created
addSessionCreatedPendingListener();
@@ -224,3 +229,5 @@ addModelSelectedListener();
addAppStartedListener();
addModelsLoadedListener();
addAppConfigReceivedListener();
+
+addUpscaleRequestedListener();
@@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger';
-import { startAppListening } from '..';
-import { sessionCreated } from 'services/api/thunks/session';
import { serializeError } from 'serialize-error';
+import { sessionCreated } from 'services/api/thunks/session';
+import { startAppListening } from '..';

const moduleLog = log.child({ namespace: 'session' });
@@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
  appSocketModelLoadCompleted,
  socketModelLoadCompleted,
} from 'services/events/actions';
import { startAppListening } from '../..';

const moduleLog = log.child({ namespace: 'socketio' });

export const addModelLoadCompletedEventListener = () => {
  startAppListening({
    actionCreator: socketModelLoadCompleted,
    effect: (action, { dispatch, getState }) => {
      const { model_name, model_type, submodel } = action.payload.data;

      let modelString = `${model_type} model: ${model_name}`;

      if (submodel) {
        modelString = modelString.concat(`, submodel: ${submodel}`);
      }

      moduleLog.debug(action.payload, `Model load completed (${modelString})`);

      // pass along the socket event as an application action
      dispatch(appSocketModelLoadCompleted(action.payload));
    },
  });
};
@@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
  appSocketModelLoadStarted,
  socketModelLoadStarted,
} from 'services/events/actions';
import { startAppListening } from '../..';

const moduleLog = log.child({ namespace: 'socketio' });

export const addModelLoadStartedEventListener = () => {
  startAppListening({
    actionCreator: socketModelLoadStarted,
    effect: (action, { dispatch, getState }) => {
      const { model_name, model_type, submodel } = action.payload.data;

      let modelString = `${model_type} model: ${model_name}`;

      if (submodel) {
        modelString = modelString.concat(`, submodel: ${submodel}`);
      }

      moduleLog.debug(action.payload, `Model load started (${modelString})`);

      // pass along the socket event as an application action
      dispatch(appSocketModelLoadStarted(action.payload));
    },
  });
};
@@ -0,0 +1,37 @@
import { createAction } from '@reduxjs/toolkit';
import { log } from 'app/logging/useLogger';
import { buildAdHocUpscaleGraph } from 'features/nodes/util/graphBuilders/buildAdHocUpscaleGraph';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';

const moduleLog = log.child({ namespace: 'upscale' });

export const upscaleRequested = createAction<{ image_name: string }>(
  `upscale/upscaleRequested`
);

export const addUpscaleRequestedListener = () => {
  startAppListening({
    actionCreator: upscaleRequested,
    effect: async (
      action,
      { dispatch, getState, take, unsubscribe, subscribe }
    ) => {
      const { image_name } = action.payload;
      const { esrganModelName } = getState().postprocessing;

      const graph = buildAdHocUpscaleGraph({
        image_name,
        esrganModelName,
      });

      // Create a session to run the graph & wait til it's ready to invoke
      dispatch(sessionCreated({ graph }));

      await take(sessionCreated.fulfilled.match);

      dispatch(sessionReadyToInvoke());
    },
  });
};
@@ -21,6 +21,7 @@ import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
+import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';

@@ -49,6 +50,7 @@ const allReducers = {
  dynamicPrompts: dynamicPromptsReducer,
  imageDeletion: imageDeletionReducer,
  lora: loraReducer,
+ modelmanager: modelmanagerReducer,
  [api.reducerPath]: api.reducer,
};

@@ -67,6 +69,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
  'controlNet',
  'dynamicPrompts',
  'lora',
+ 'modelmanager',
];

export const store = configureStore({
@@ -21,6 +21,7 @@ import { ImageDTO } from 'services/api/types';
import { mode } from 'theme/util/mode';
import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable';
+import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';

type IAIDndImageProps = {
  imageDTO: ImageDTO | undefined;
@@ -96,119 +97,124 @@ const IAIDndImage = (props: IAIDndImageProps) => {
  };

  return (
    <Flex
      sx={{
        width: 'full',
        height: 'full',
        alignItems: 'center',
        justifyContent: 'center',
        position: 'relative',
        minW: minSize ? minSize : undefined,
        minH: minSize ? minSize : undefined,
        userSelect: 'none',
        cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
      }}
    >
      {imageDTO && (
        <ImageContextMenu imageDTO={imageDTO}>
          {(ref) => (
            <Flex
              ref={ref}
              sx={{
                w: 'full',
                h: 'full',
                position: fitContainer ? 'absolute' : 'relative',
                width: 'full',
                height: 'full',
                alignItems: 'center',
                justifyContent: 'center',
                position: 'relative',
                minW: minSize ? minSize : undefined,
                minH: minSize ? minSize : undefined,
                userSelect: 'none',
                cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
              }}
            >
              <Image
                src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
                fallbackStrategy="beforeLoadOrError"
                // If we fall back to thumbnail, it feels much snappier than the skeleton...
                fallbackSrc={imageDTO.thumbnail_url}
                // fallback={<IAILoadingImageFallback image={imageDTO} />}
                width={imageDTO.width}
                height={imageDTO.height}
                onError={onError}
                draggable={false}
                sx={{
                  objectFit: 'contain',
                  maxW: 'full',
                  maxH: 'full',
                  borderRadius: 'base',
                  shadow: isSelected ? 'selected.light' : undefined,
                  _dark: { shadow: isSelected ? 'selected.dark' : undefined },
                  ...imageSx,
                }}
              />
              {withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
            </Flex>
          )}
          {!imageDTO && !isUploadDisabled && (
            <>
              <Flex
                sx={{
                  minH: minSize,
                  w: 'full',
                  h: 'full',
                  alignItems: 'center',
                  justifyContent: 'center',
                  borderRadius: 'base',
                  transitionProperty: 'common',
                  transitionDuration: '0.1s',
                  color: mode('base.500', 'base.500')(colorMode),
                  ...uploadButtonStyles,
                }}
                {...getUploadButtonProps()}
              >
                <input {...getUploadInputProps()} />
                <Icon
                  as={FaUpload}
          {imageDTO && (
            <Flex
              sx={{
                boxSize: 16,
                w: 'full',
                h: 'full',
                position: fitContainer ? 'absolute' : 'relative',
                alignItems: 'center',
                justifyContent: 'center',
              }}
            >
              <Image
                src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
                fallbackStrategy="beforeLoadOrError"
                // If we fall back to thumbnail, it feels much snappier than the skeleton...
                fallbackSrc={imageDTO.thumbnail_url}
                // fallback={<IAILoadingImageFallback image={imageDTO} />}
                width={imageDTO.width}
                height={imageDTO.height}
                onError={onError}
                draggable={false}
                sx={{
                  objectFit: 'contain',
                  maxW: 'full',
                  maxH: 'full',
                  borderRadius: 'base',
                  shadow: isSelected ? 'selected.light' : undefined,
                  _dark: { shadow: isSelected ? 'selected.dark' : undefined },
                  ...imageSx,
                }}
              />
              {withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
            </Flex>
          )}
          {!imageDTO && !isUploadDisabled && (
            <>
              <Flex
                sx={{
                  minH: minSize,
                  w: 'full',
                  h: 'full',
                  alignItems: 'center',
                  justifyContent: 'center',
                  borderRadius: 'base',
                  transitionProperty: 'common',
                  transitionDuration: '0.1s',
                  color: mode('base.500', 'base.500')(colorMode),
                  ...uploadButtonStyles,
                }}
                {...getUploadButtonProps()}
              >
                <input {...getUploadInputProps()} />
                <Icon
                  as={FaUpload}
                  sx={{
                    boxSize: 16,
                  }}
                />
              </Flex>
            </>
          )}
          {!imageDTO && isUploadDisabled && noContentFallback}
          {!isDropDisabled && (
            <IAIDroppable
              data={droppableData}
              disabled={isDropDisabled}
              dropLabel={dropLabel}
            />
          )}
          {imageDTO && !isDragDisabled && (
            <IAIDraggable
              data={draggableData}
              disabled={isDragDisabled || !imageDTO}
              onClick={onClick}
            />
          )}
          {onClickReset && withResetIcon && imageDTO && (
            <IAIIconButton
              onClick={onClickReset}
              aria-label={resetTooltip}
              tooltip={resetTooltip}
              icon={resetIcon}
              size="sm"
              variant="link"
              sx={{
                position: 'absolute',
                top: 1,
                insetInlineEnd: 1,
                p: 0,
                minW: 0,
                svg: {
                  transitionProperty: 'common',
                  transitionDuration: 'normal',
                  fill: 'base.100',
                  _hover: { fill: 'base.50' },
                  filter: resetIconShadow,
                },
              }}
            />
          </Flex>
        </>
      )}
    </Flex>
  )}
  {!imageDTO && isUploadDisabled && noContentFallback}
  {!isDropDisabled && (
    <IAIDroppable
      data={droppableData}
      disabled={isDropDisabled}
      dropLabel={dropLabel}
    />
  )}
  {imageDTO && !isDragDisabled && (
    <IAIDraggable
      data={draggableData}
      disabled={isDragDisabled || !imageDTO}
      onClick={onClick}
    />
  )}
  {onClickReset && withResetIcon && imageDTO && (
    <IAIIconButton
      onClick={onClickReset}
      aria-label={resetTooltip}
      tooltip={resetTooltip}
      icon={resetIcon}
      size="sm"
      variant="link"
      sx={{
        position: 'absolute',
        top: 1,
        insetInlineEnd: 1,
        p: 0,
        minW: 0,
        svg: {
          transitionProperty: 'common',
          transitionDuration: 'normal',
          fill: 'base.100',
          _hover: { fill: 'base.50' },
          filter: resetIconShadow,
        },
      }}
    />
  )}
</Flex>
</ImageContextMenu>
);
};
@@ -8,19 +8,34 @@ import {
import { useAppDispatch } from 'app/store/storeHooks';
import { stopPastePropagation } from 'common/util/stopPastePropagation';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
-import { ChangeEvent, KeyboardEvent, memo, useCallback } from 'react';
+import {
+  CSSProperties,
+  ChangeEvent,
+  KeyboardEvent,
+  memo,
+  useCallback,
+} from 'react';

interface IAIInputProps extends InputProps {
  label?: string;
+ labelPos?: 'top' | 'side';
  value?: string;
  size?: string;
  onChange?: (e: ChangeEvent<HTMLInputElement>) => void;
  formControlProps?: Omit<FormControlProps, 'isInvalid' | 'isDisabled'>;
}

+const labelPosVerticalStyle: CSSProperties = {
+  display: 'flex',
+  flexDirection: 'row',
+  alignItems: 'center',
+  gap: 10,
+};
+
const IAIInput = (props: IAIInputProps) => {
  const {
    label = '',
+   labelPos = 'top',
    isDisabled = false,
    isInvalid,
    formControlProps,
@@ -51,6 +66,7 @@ const IAIInput = (props: IAIInputProps) => {
    isInvalid={isInvalid}
    isDisabled={isDisabled}
    {...formControlProps}
+   style={labelPos === 'side' ? labelPosVerticalStyle : undefined}
  >
    {label !== '' && <FormLabel>{label}</FormLabel>}
    <Input
@@ -36,6 +36,7 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
  label: {
    color: mode(base700, base300)(colorMode),
    fontWeight: 'normal',
+   marginBottom: 4,
  },
})}
{...rest}
@@ -9,14 +9,14 @@ export type IAISelectDataType = {
  tooltip?: string;
};

-type IAISelectProps = Omit<SelectProps, 'label'> & {
+export type IAISelectProps = Omit<SelectProps, 'label'> & {
  tooltip?: string;
  inputRef?: RefObject<HTMLInputElement>;
  label?: string;
};

const IAIMantineSelect = (props: IAISelectProps) => {
- const { tooltip, inputRef, label, disabled, ...rest } = props;
+ const { tooltip, inputRef, label, disabled, required, ...rest } = props;

  const styles = useMantineSelectStyles();

@@ -25,7 +25,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
  <Select
    label={
      label ? (
-       <FormControl isDisabled={disabled}>
+       <FormControl isRequired={required} isDisabled={disabled}>
          <FormLabel>{label}</FormLabel>
        </FormControl>
      ) : undefined
@@ -11,7 +11,7 @@ interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {

const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
  ({ label, tooltip, description, disabled, ...others }: ItemProps, ref) => (
-   <Tooltip label={tooltip} placement="top" hasArrow>
+   <Tooltip label={tooltip} placement="top" hasArrow openDelay={500}>
      <Box ref={ref} {...others}>
        <Box>
          <Text>{label}</Text>
@@ -3,4 +3,5 @@ import dateFormat from 'dateformat';
/**
 * Get a `now` timestamp with 1s precision, formatted as ISO datetime.
 */
-export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');
+export const getTimestamp = () =>
+  dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);
@@ -11,6 +11,7 @@ import {
  setIsMouseOverBoundingBox,
  setIsMovingBoundingBox,
  setIsTransformingBoundingBox,
+ setShouldSnapToGrid,
} from 'features/canvas/store/canvasSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import Konva from 'konva';
@@ -20,6 +21,7 @@ import { Vector2d } from 'konva/lib/types';
import { isEqual } from 'lodash-es';

import { useCallback, useEffect, useRef, useState } from 'react';
+import { useHotkeys } from 'react-hotkeys-hook';
import { Group, Rect, Transformer } from 'react-konva';

const boundingBoxPreviewSelector = createSelector(
@@ -91,6 +93,10 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {

  const scaledStep = 64 * stageScale;

+ useHotkeys('N', () => {
+   dispatch(setShouldSnapToGrid(!shouldSnapToGrid));
+ });
+
  const handleOnDragMove = useCallback(
    (e: KonvaEventObject<DragEvent>) => {
      if (!shouldSnapToGrid) {
@@ -139,7 +139,7 @@ const IAICanvasToolChooserOptions = () => {
  );

  useHotkeys(
-   ['shift+BracketLeft'],
+   ['Shift+BracketLeft'],
    () => {
      dispatch(
        setBrushColor({
@@ -156,7 +156,7 @@ const IAICanvasToolChooserOptions = () => {
  );

  useHotkeys(
-   ['shift+BracketRight'],
+   ['Shift+BracketRight'],
    () => {
      dispatch(
        setBrushColor({

@@ -48,6 +48,7 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
import IAICanvasUndoButton from './IAICanvasUndoButton';
+ import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';

export const selector = createSelector(
  [systemSelector, canvasSelector, isStagingSelector],
@@ -79,6 +80,7 @@ const IAICanvasToolbar = () => {
  const canvasBaseLayer = getCanvasBaseLayer();

  const { t } = useTranslation();
+ const { isClipboardAPIAvailable } = useCopyImageToClipboard();

  const { openUploader } = useImageUploader();

@@ -136,10 +138,10 @@ const IAICanvasToolbar = () => {
      handleCopyImageToClipboard();
    },
    {
-     enabled: () => !isStaging,
+     enabled: () => !isStaging && isClipboardAPIAvailable,
      preventDefault: true,
    },
-   [canvasBaseLayer, isProcessing]
+   [canvasBaseLayer, isProcessing, isClipboardAPIAvailable]
  );

  useHotkeys(
@@ -189,6 +191,9 @@ const IAICanvasToolbar = () => {
  };

  const handleCopyImageToClipboard = () => {
+   if (!isClipboardAPIAvailable) {
+     return;
+   }
    dispatch(canvasCopiedToClipboard());
  };

@@ -256,13 +261,15 @@ const IAICanvasToolbar = () => {
        onClick={handleSaveToGallery}
        isDisabled={isStaging}
      />
-     <IAIIconButton
-       aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
-       tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
-       icon={<FaCopy />}
-       onClick={handleCopyImageToClipboard}
-       isDisabled={isStaging}
-     />
+     {isClipboardAPIAvailable && (
+       <IAIIconButton
+         aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
+         tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
+         icon={<FaCopy />}
+         onClick={handleCopyImageToClipboard}
+         isDisabled={isStaging}
+       />
+     )}
      <IAIIconButton
        aria-label={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
        tooltip={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}

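The hunks above gate the copy button behind `isClipboardAPIAvailable`, but the hook itself is not shown in this diff. A minimal sketch of a `useCopyImageToClipboard` hook consistent with that usage — an assumption for illustration, not the project's actual implementation:

// Sketch only: feature-detects the async Clipboard API and exposes a copy helper.
import { useCallback, useMemo } from 'react';

export const useCopyImageToClipboard = () => {
  // ClipboardItem is missing in some browsers, so detect it rather than assume it.
  const isClipboardAPIAvailable = useMemo(
    () => Boolean(navigator.clipboard) && Boolean(window.ClipboardItem),
    []
  );

  const copyImageToClipboard = useCallback(
    async (image_url: string) => {
      if (!isClipboardAPIAvailable) {
        return;
      }
      // Fetch the image as a blob and hand it to the clipboard keyed by MIME type.
      const blob = await fetch(image_url).then((res) => res.blob());
      await navigator.clipboard.write([new ClipboardItem({ [blob.type]: blob })]);
    },
    [isClipboardAPIAvailable]
  );

  return { isClipboardAPIAvailable, copyImageToClipboard };
};
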
@@ -1,28 +1,30 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash-es';

- import { ButtonGroup, Flex, FlexProps, Link } from '@chakra-ui/react';
+ import {
+   ButtonGroup,
+   Flex,
+   FlexProps,
+   Menu,
+   MenuButton,
+   MenuList,
+ } from '@chakra-ui/react';
// import { runESRGAN, runFacetool } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';

import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster';
+ import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { stateSelector } from 'app/store/store';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { DeleteImageButton } from 'features/imageDeletion/components/DeleteImageButton';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
+ import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
  setActiveTab,
  setShouldShowImageDetails,
  setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice';
@@ -32,36 +34,25 @@ import { useTranslation } from 'react-i18next';
import {
  FaAsterisk,
  FaCode,
  FaCopy,
  FaDownload,
  FaExpandArrowsAlt,
  FaGrinStars,
  FaHourglassHalf,
  FaQuoteRight,
  FaSeedling,
  FaShare,
  FaShareAlt,
} from 'react-icons/fa';
import {
  useGetImageDTOQuery,
  useGetImageMetadataQuery,
} from 'services/api/endpoints/images';
+ import { menuListMotionProps } from 'theme/components/menu';
import { useDebounce } from 'use-debounce';
- import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
+ import { sentImageToImg2Img } from '../../store/actions';
+ import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';

const currentImageButtonsSelector = createSelector(
  [stateSelector, activeTabNameSelector],
- ({ gallery, system, postprocessing, ui }, activeTabName) => {
-   const {
-     isProcessing,
-     isConnected,
-     isGFPGANAvailable,
-     isESRGANAvailable,
-     shouldConfirmOnDelete,
-     progressImage,
-   } = system;
-
-   const { upscalingLevel, facetoolStrength } = postprocessing;
+ ({ gallery, system, ui }, activeTabName) => {
+   const { isProcessing, isConnected, shouldConfirmOnDelete, progressImage } =
+     system;

    const {
      shouldShowImageDetails,
@@ -76,10 +67,6 @@ const currentImageButtonsSelector = createSelector(
      shouldConfirmOnDelete,
      isProcessing,
      isConnected,
-     isGFPGANAvailable,
-     isESRGANAvailable,
-     upscalingLevel,
-     facetoolStrength,
      shouldDisableToolbarButtons: Boolean(progressImage) || !lastSelectedImage,
      shouldShowImageDetails,
      activeTabName,
@@ -102,20 +89,13 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
  const {
    isProcessing,
    isConnected,
-   isGFPGANAvailable,
-   isESRGANAvailable,
-   upscalingLevel,
-   facetoolStrength,
    shouldDisableToolbarButtons,
    shouldShowImageDetails,
    activeTabName,
    lastSelectedImage,
    shouldShowProgressInViewer,
  } = useAppSelector(currentImageButtonsSelector);

  const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
  const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
  const isFaceRestoreEnabled = useFeatureStatus('faceRestore').isFeatureEnabled;

  const toaster = useAppToaster();
  const { t } = useTranslation();
@@ -128,7 +108,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
    500
  );

- const { currentData: image, isFetching } = useGetImageDTOQuery(
+ const { currentData: imageDTO, isFetching } = useGetImageDTOQuery(
    lastSelectedImage ?? skipToken
  );

@@ -140,42 +120,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {

  const metadata = metadataData?.metadata;

- const handleCopyImageLink = useCallback(() => {
-   const getImageUrl = () => {
-     if (!image) {
-       return;
-     }
-
-     if (image.image_url.startsWith('http')) {
-       return image.image_url;
-     }
-
-     return window.location.toString() + image.image_url;
-   };
-
-   const url = getImageUrl();
-
-   if (!url) {
-     toaster({
-       title: t('toast.problemCopyingImageLink'),
-       status: 'error',
-       duration: 2500,
-       isClosable: true,
-     });
-
-     return;
-   }
-
-   navigator.clipboard.writeText(url).then(() => {
-     toaster({
-       title: t('toast.imageLinkCopied'),
-       status: 'success',
-       duration: 2500,
-       isClosable: true,
-     });
-   });
- }, [toaster, t, image]);

  const handleClickUseAllParameters = useCallback(() => {
    recallAllParameters(metadata);
  }, [metadata, recallAllParameters]);
@@ -192,31 +136,34 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
    recallSeed(metadata?.seed);
  }, [metadata?.seed, recallSeed]);

- useHotkeys('s', handleUseSeed, [image]);
+ useHotkeys('s', handleUseSeed, [imageDTO]);

  const handleUsePrompt = useCallback(() => {
    recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
  }, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);

- useHotkeys('p', handleUsePrompt, [image]);
+ useHotkeys('p', handleUsePrompt, [imageDTO]);

  const handleSendToImageToImage = useCallback(() => {
    dispatch(sentImageToImg2Img());
-   dispatch(initialImageSelected(image));
- }, [dispatch, image]);
+   dispatch(initialImageSelected(imageDTO));
+ }, [dispatch, imageDTO]);

- useHotkeys('shift+i', handleSendToImageToImage, [image]);
+ useHotkeys('shift+i', handleSendToImageToImage, [imageDTO]);

  const handleClickUpscale = useCallback(() => {
-   // selectedImage && dispatch(runESRGAN(selectedImage));
- }, []);
-
- const handleDelete = useCallback(() => {
-   if (!image) {
+   if (!imageDTO) {
      return;
    }
-   dispatch(imageToDeleteSelected(image));
- }, [dispatch, image]);
+   dispatch(upscaleRequested({ image_name: imageDTO.image_name }));
+ }, [dispatch, imageDTO]);
+
+ const handleDelete = useCallback(() => {
+   if (!imageDTO) {
+     return;
+   }
+   dispatch(imageToDeleteSelected(imageDTO));
+ }, [dispatch, imageDTO]);

  useHotkeys(
    'Shift+U',
@@ -227,53 +174,17 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
      enabled: () =>
        Boolean(
          isUpscalingEnabled &&
-           isESRGANAvailable &&
            !shouldDisableToolbarButtons &&
            isConnected &&
-           !isProcessing &&
-           upscalingLevel
+           !isProcessing
        ),
    },
    [
      isUpscalingEnabled,
-     image,
-     isESRGANAvailable,
+     imageDTO,
      shouldDisableToolbarButtons,
      isConnected,
      isProcessing,
-     upscalingLevel,
    ]
  );

- const handleClickFixFaces = useCallback(() => {
-   // selectedImage && dispatch(runFacetool(selectedImage));
- }, []);
-
- useHotkeys(
-   'Shift+R',
-   () => {
-     handleClickFixFaces();
-   },
-   {
-     enabled: () =>
-       Boolean(
-         isFaceRestoreEnabled &&
-           isGFPGANAvailable &&
-           !shouldDisableToolbarButtons &&
-           isConnected &&
-           !isProcessing &&
-           facetoolStrength
-       ),
-   },
-
-   [
-     isFaceRestoreEnabled,
-     image,
-     isGFPGANAvailable,
-     shouldDisableToolbarButtons,
-     isConnected,
-     isProcessing,
-     facetoolStrength,
-   ]
- );

@@ -282,29 +193,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
    [dispatch, shouldShowImageDetails]
  );

- const handleSendToCanvas = useCallback(() => {
-   if (!image) return;
-   dispatch(sentImageToCanvas());
-
-   dispatch(setInitialCanvasImage(image));
-   dispatch(requestCanvasRescale());
-
-   if (activeTabName !== 'unifiedCanvas') {
-     dispatch(setActiveTab('unifiedCanvas'));
-   }
-
-   toaster({
-     title: t('toast.sentToUnifiedCanvas'),
-     status: 'success',
-     duration: 2500,
-     isClosable: true,
-   });
- }, [image, dispatch, activeTabName, toaster, t]);

  useHotkeys(
    'i',
    () => {
-     if (image) {
+     if (imageDTO) {
        handleClickShowImageDetails();
      } else {
        toaster({
@@ -315,7 +207,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
        });
      }
    },
-   [image, shouldShowImageDetails, toaster]
+   [imageDTO, shouldShowImageDetails, toaster]
  );

  const handleClickProgressImagesToggle = useCallback(() => {
@@ -334,63 +226,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
      {...props}
    >
      <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
-       <IAIPopover
-         triggerComponent={
-           <IAIIconButton
-             aria-label={`${t('parameters.sendTo')}...`}
-             tooltip={`${t('parameters.sendTo')}...`}
-             isDisabled={!image}
-             icon={<FaShareAlt />}
-           />
-         }
-       >
-         <Flex
-           sx={{
-             flexDirection: 'column',
-             rowGap: 2,
-           }}
-         >
-           <IAIButton
-             size="sm"
-             onClick={handleSendToImageToImage}
-             leftIcon={<FaShare />}
-             id="send-to-img2img"
-           >
-             {t('parameters.sendToImg2Img')}
-           </IAIButton>
-           {isCanvasEnabled && (
-             <IAIButton
-               size="sm"
-               onClick={handleSendToCanvas}
-               leftIcon={<FaShare />}
-               id="send-to-canvas"
-             >
-               {t('parameters.sendToUnifiedCanvas')}
-             </IAIButton>
-           )}
-
-           {/* <IAIButton
-             size="sm"
-             onClick={handleCopyImage}
-             leftIcon={<FaCopy />}
-           >
-             {t('parameters.copyImage')}
-           </IAIButton> */}
-           <IAIButton
-             size="sm"
-             onClick={handleCopyImageLink}
-             leftIcon={<FaCopy />}
-           >
-             {t('parameters.copyImageToLink')}
-           </IAIButton>
-
-           <Link download={true} href={image?.image_url} target="_blank">
-             <IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
-               {t('parameters.downloadImage')}
-             </IAIButton>
-           </Link>
-         </Flex>
-       </IAIPopover>
+       <Menu>
+         <MenuButton
+           as={IAIIconButton}
+           aria-label={`${t('parameters.sendTo')}...`}
+           tooltip={`${t('parameters.sendTo')}...`}
+           isDisabled={!imageDTO}
+           icon={<FaShareAlt />}
+         />
+         <MenuList motionProps={menuListMotionProps}>
+           {imageDTO && <SingleSelectionMenuItems imageDTO={imageDTO} />}
+         </MenuList>
+       </Menu>
      </ButtonGroup>

      <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
@@ -419,72 +266,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
        />
      </ButtonGroup>

-     {(isUpscalingEnabled || isFaceRestoreEnabled) && (
+     {isUpscalingEnabled && (
        <ButtonGroup
          isAttached={true}
          isDisabled={shouldDisableToolbarButtons}
        >
-         {isFaceRestoreEnabled && (
-           <IAIPopover
-             triggerComponent={
-               <IAIIconButton
-                 icon={<FaGrinStars />}
-                 aria-label={t('parameters.restoreFaces')}
-               />
-             }
-           >
-             <Flex
-               sx={{
-                 flexDirection: 'column',
-                 rowGap: 4,
-               }}
-             >
-               <FaceRestoreSettings />
-               <IAIButton
-                 isDisabled={
-                   !isGFPGANAvailable ||
-                   !image ||
-                   !(isConnected && !isProcessing) ||
-                   !facetoolStrength
-                 }
-                 onClick={handleClickFixFaces}
-               >
-                 {t('parameters.restoreFaces')}
-               </IAIButton>
-             </Flex>
-           </IAIPopover>
-         )}
-
-         {isUpscalingEnabled && (
-           <IAIPopover
-             triggerComponent={
-               <IAIIconButton
-                 icon={<FaExpandArrowsAlt />}
-                 aria-label={t('parameters.upscale')}
-               />
-             }
-           >
-             <Flex
-               sx={{
-                 flexDirection: 'column',
-                 gap: 4,
-               }}
-             >
-               <UpscaleSettings />
-               <IAIButton
-                 isDisabled={
-                   !isESRGANAvailable ||
-                   !image ||
-                   !(isConnected && !isProcessing) ||
-                   !upscalingLevel
-                 }
-                 onClick={handleClickUpscale}
-               >
-                 {t('parameters.upscaleImage')}
-               </IAIButton>
-             </Flex>
-           </IAIPopover>
-         )}
+         {isUpscalingEnabled && <ParamUpscalePopover imageDTO={imageDTO} />}
        </ButtonGroup>
      )}

@@ -4,13 +4,14 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
- import { memo, useMemo } from 'react';
+ import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { ImageDTO } from 'services/api/types';
+ import { menuListMotionProps } from 'theme/components/menu';
import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
import SingleSelectionMenuItems from './SingleSelectionMenuItems';

type Props = {
- imageDTO: ImageDTO;
+ imageDTO: ImageDTO | undefined;
  children: ContextMenuProps<HTMLDivElement>['children'];
};

@@ -31,18 +32,32 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {

  const { selectionCount } = useAppSelector(selector);

+ const handleContextMenu = useCallback((e: MouseEvent<HTMLDivElement>) => {
+   e.preventDefault();
+ }, []);

  return (
    <ContextMenu<HTMLDivElement>
      menuProps={{ size: 'sm', isLazy: true }}
-     renderMenu={() => (
-       <MenuList sx={{ visibility: 'visible !important' }}>
-         {selectionCount === 1 ? (
-           <SingleSelectionMenuItems imageDTO={imageDTO} />
-         ) : (
-           <MultipleSelectionMenuItems />
-         )}
-       </MenuList>
-     )}
+     menuButtonProps={{
+       bg: 'transparent',
+       _hover: { bg: 'transparent' },
+     }}
+     renderMenu={() =>
+       imageDTO ? (
+         <MenuList
+           sx={{ visibility: 'visible !important' }}
+           motionProps={menuListMotionProps}
+           onContextMenu={handleContextMenu}
+         >
+           {selectionCount === 1 ? (
+             <SingleSelectionMenuItems imageDTO={imageDTO} />
+           ) : (
+             <MultipleSelectionMenuItems />
+           )}
+         </MenuList>
+       ) : null
+     }
    >
      {children}
    </ContextMenu>

@@ -1,5 +1,4 @@
- import { ExternalLinkIcon } from '@chakra-ui/icons';
- import { MenuItem } from '@chakra-ui/react';
+ import { Link, MenuItem } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store';
@@ -14,11 +13,21 @@ import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletio
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
+ import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback, useContext, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
- import { FaFolder, FaShare, FaTrash } from 'react-icons/fa';
- import { IoArrowUndoCircleOutline } from 'react-icons/io5';
+ import {
+   FaAsterisk,
+   FaCopy,
+   FaDownload,
+   FaExternalLinkAlt,
+   FaFolder,
+   FaQuoteRight,
+   FaSeedling,
+   FaShare,
+   FaTrash,
+ } from 'react-icons/fa';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
@@ -61,6 +70,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {

  const { currentData } = useGetImageMetadataQuery(imageDTO.image_name);

+ const { isClipboardAPIAvailable, copyImageToClipboard } =
+   useCopyImageToClipboard();

  const metadata = currentData?.metadata;

  const handleDelete = useCallback(() => {
@@ -130,13 +142,27 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
    dispatch(imagesAddedToBatch([imageDTO.image_name]));
  }, [dispatch, imageDTO.image_name]);

+ const handleCopyImage = useCallback(() => {
+   copyImageToClipboard(imageDTO.image_url);
+ }, [copyImageToClipboard, imageDTO.image_url]);

  return (
    <>
-     <MenuItem icon={<ExternalLinkIcon />} onClickCapture={handleOpenInNewTab}>
-       {t('common.openInNewTab')}
-     </MenuItem>
+     <Link href={imageDTO.image_url} target="_blank">
+       <MenuItem
+         icon={<FaExternalLinkAlt />}
+         onClickCapture={handleOpenInNewTab}
+       >
+         {t('common.openInNewTab')}
+       </MenuItem>
+     </Link>
+     {isClipboardAPIAvailable && (
+       <MenuItem icon={<FaCopy />} onClickCapture={handleCopyImage}>
+         {t('parameters.copyImage')}
+       </MenuItem>
+     )}
      <MenuItem
-       icon={<IoArrowUndoCircleOutline />}
+       icon={<FaQuoteRight />}
        onClickCapture={handleRecallPrompt}
        isDisabled={
          metadata?.positive_prompt === undefined &&
@@ -147,14 +173,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
      </MenuItem>

      <MenuItem
-       icon={<IoArrowUndoCircleOutline />}
+       icon={<FaSeedling />}
        onClickCapture={handleRecallSeed}
        isDisabled={metadata?.seed === undefined}
      >
        {t('parameters.useSeed')}
      </MenuItem>
      <MenuItem
-       icon={<IoArrowUndoCircleOutline />}
+       icon={<FaAsterisk />}
        onClickCapture={handleUseAllParameters}
        isDisabled={!metadata}
      >
@@ -193,6 +219,11 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
        Remove from Board
      </MenuItem>
    )}
+   <Link download={true} href={imageDTO.image_url} target="_blank">
+     <MenuItem icon={<FaDownload />} w="100%">
+       {t('parameters.downloadImage')}
+     </MenuItem>
+   </Link>
    <MenuItem
      sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
      icon={<FaTrash />}

@@ -16,14 +16,13 @@ import {
  ASSETS_CATEGORIES,
  IMAGE_CATEGORIES,
  IMAGE_LIMIT,
- selectImagesAll,
} from 'features/gallery//store/gallerySlice';
import { selectFilteredImages } from 'features/gallery/store/gallerySelectors';
import { VirtuosoGrid } from 'react-virtuoso';
import { receivedPageOfImages } from 'services/api/thunks/image';
+ import { useListBoardImagesQuery } from '../../../../services/api/endpoints/boardImages';
import ImageGridItemContainer from './ImageGridItemContainer';
import ImageGridListContainer from './ImageGridListContainer';
- import { useListBoardImagesQuery } from '../../../../services/api/endpoints/boardImages';

const selector = createSelector(
  [stateSelector, selectFilteredImages],
@@ -180,7 +179,6 @@ const GalleryImageGrid = () => {
      </Box>
    );
  }
- console.log({ selectedBoardId });

  if (status !== 'rejected') {
    return (

@@ -110,8 +110,11 @@ const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
    return (
      <div ref={ref} {...others}>
        <div>
-         <Text>{label}</Text>
-         <Text size="xs" color="base.600">
+         <Text fontWeight={600}>{label}</Text>
+         <Text
+           size="xs"
+           sx={{ color: 'base.600', _dark: { color: 'base.500' } }}
+         >
            {description}
          </Text>
        </div>

@@ -20,8 +20,8 @@ const IAINodeHeader = (props: IAINodeHeaderProps) => {
        justifyContent: 'space-between',
        px: 2,
        py: 1,
-       bg: 'base.300',
-       _dark: { bg: 'base.700' },
+       bg: 'base.100',
+       _dark: { bg: 'base.900' },
      }}
    >
      <Tooltip label={nodeId}>
@@ -30,7 +30,7 @@ const IAINodeHeader = (props: IAINodeHeaderProps) => {
        sx={{
          fontWeight: 600,
          color: 'base.900',
-         _dark: { color: 'base.100' },
+         _dark: { color: 'base.200' },
        }}
      >
        {title}

@@ -59,7 +59,7 @@ export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
        flexDirection: 'column',
        borderBottomRadius: 'md',
        py: 2,
-       bg: 'base.200',
+       bg: 'base.150',
        _dark: { bg: 'base.800' },
      }}
    >

@@ -1,9 +1,9 @@
- import 'reactflow/dist/style.css';
import { Box } from '@chakra-ui/react';
import { ReactFlowProvider } from 'reactflow';
+ import 'reactflow/dist/style.css';

- import { Flow } from './Flow';
import { memo } from 'react';
+ import { Flow } from './Flow';

const NodeEditor = () => {
  return (

@@ -1,9 +1,9 @@
import { Box, useToken } from '@chakra-ui/react';
import { NODE_MIN_WIDTH } from 'app/constants';

+ import { useAppSelector } from 'app/store/storeHooks';
import { PropsWithChildren } from 'react';
import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation';
- import { useAppSelector } from 'app/store/storeHooks';

type NodeWrapperProps = PropsWithChildren & {
  selected: boolean;

@@ -1,17 +1,36 @@
- import { ButtonGroup } from '@chakra-ui/react';
+ import { ButtonGroup, Tooltip } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { memo, useCallback } from 'react';
- import { FaCode, FaExpand, FaMinus, FaPlus } from 'react-icons/fa';
+ import {
+   FaCode,
+   FaExpand,
+   FaMinus,
+   FaPlus,
+   FaInfo,
+   FaMapMarkerAlt,
+ } from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
- import { shouldShowGraphOverlayChanged } from '../store/nodesSlice';
+ import { useTranslation } from 'react-i18next';
+ import {
+   shouldShowGraphOverlayChanged,
+   shouldShowFieldTypeLegendChanged,
+   shouldShowMinimapPanelChanged,
+ } from '../store/nodesSlice';

const ViewportControls = () => {
+ const { t } = useTranslation();
  const { zoomIn, zoomOut, fitView } = useReactFlow();
  const dispatch = useAppDispatch();
  const shouldShowGraphOverlay = useAppSelector(
    (state) => state.nodes.shouldShowGraphOverlay
  );
+ const shouldShowFieldTypeLegend = useAppSelector(
+   (state) => state.nodes.shouldShowFieldTypeLegend
+ );
+ const shouldShowMinimapPanel = useAppSelector(
+   (state) => state.nodes.shouldShowMinimapPanel
+ );

  const handleClickedZoomIn = useCallback(() => {
    zoomIn();
@@ -29,29 +48,64 @@ const ViewportControls = () => {
    dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
  }, [shouldShowGraphOverlay, dispatch]);

+ const handleClickedToggleFieldTypeLegend = useCallback(() => {
+   dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
+ }, [shouldShowFieldTypeLegend, dispatch]);
+
+ const handleClickedToggleMiniMapPanel = useCallback(() => {
+   dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));
+ }, [shouldShowMinimapPanel, dispatch]);

  return (
    <ButtonGroup isAttached orientation="vertical">
-     <IAIIconButton
-       onClick={handleClickedZoomIn}
-       aria-label="Zoom In"
-       icon={<FaPlus />}
-     />
-     <IAIIconButton
-       onClick={handleClickedZoomOut}
-       aria-label="Zoom Out"
-       icon={<FaMinus />}
-     />
-     <IAIIconButton
-       onClick={handleClickedFitView}
-       aria-label="Fit to Viewport"
-       icon={<FaExpand />}
-     />
-     <IAIIconButton
-       isChecked={shouldShowGraphOverlay}
-       onClick={handleClickedToggleGraphOverlay}
-       aria-label="Show/Hide Graph"
-       icon={<FaCode />}
-     />
+     <Tooltip label={t('nodes.zoomInNodes')}>
+       <IAIIconButton onClick={handleClickedZoomIn} icon={<FaPlus />} />
+     </Tooltip>
+     <Tooltip label={t('nodes.zoomOutNodes')}>
+       <IAIIconButton onClick={handleClickedZoomOut} icon={<FaMinus />} />
+     </Tooltip>
+     <Tooltip label={t('nodes.fitViewportNodes')}>
+       <IAIIconButton onClick={handleClickedFitView} icon={<FaExpand />} />
+     </Tooltip>
+     <Tooltip
+       label={
+         shouldShowGraphOverlay
+           ? t('nodes.hideGraphNodes')
+           : t('nodes.showGraphNodes')
+       }
+     >
+       <IAIIconButton
+         isChecked={shouldShowGraphOverlay}
+         onClick={handleClickedToggleGraphOverlay}
+         icon={<FaCode />}
+       />
+     </Tooltip>
+     <Tooltip
+       label={
+         shouldShowFieldTypeLegend
+           ? t('nodes.hideLegendNodes')
+           : t('nodes.showLegendNodes')
+       }
+     >
+       <IAIIconButton
+         isChecked={shouldShowFieldTypeLegend}
+         onClick={handleClickedToggleFieldTypeLegend}
+         icon={<FaInfo />}
+       />
+     </Tooltip>
+     <Tooltip
+       label={
+         shouldShowMinimapPanel
+           ? t('nodes.hideMinimapnodes')
+           : t('nodes.showMinimapnodes')
+       }
+     >
+       <IAIIconButton
+         isChecked={shouldShowMinimapPanel}
+         onClick={handleClickedToggleMiniMapPanel}
+         icon={<FaMapMarkerAlt />}
+       />
+     </Tooltip>
    </ButtonGroup>
  );
};

@@ -12,7 +12,10 @@ import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainM
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
- import { useGetMainModelsQuery } from 'services/api/endpoints/models';
+ import {
+   useGetMainModelsQuery,
+   useGetOnnxModelsQuery,
+ } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';

const ModelInputFieldComponent = (
@@ -23,6 +26,7 @@ const ModelInputFieldComponent = (
  const dispatch = useAppDispatch();
  const { t } = useTranslation();

+ const { data: onnxModels } = useGetOnnxModelsQuery();
  const { data: mainModels, isLoading } = useGetMainModelsQuery();

  const data = useMemo(() => {
@@ -44,17 +48,39 @@ const ModelInputFieldComponent = (
      });
    });

+   if (onnxModels) {
+     forEach(onnxModels.entities, (model, id) => {
+       if (!model) {
+         return;
+       }
+
+       data.push({
+         value: id,
+         label: model.model_name,
+         group: BASE_MODEL_NAME_MAP[model.base_model],
+       });
+     });
+   }
    return data;
- }, [mainModels]);
+ }, [mainModels, onnxModels]);

  // grab the full model entity from the RTK Query cache
  // TODO: maybe we should just store the full model entity in state?
  const selectedModel = useMemo(
    () =>
-     mainModels?.entities[
+     (mainModels?.entities[
        `${field.value?.base_model}/main/${field.value?.model_name}`
-     ] ?? null,
-   [field.value?.base_model, field.value?.model_name, mainModels?.entities]
+     ] ||
+       onnxModels?.entities[
+         `${field.value?.base_model}/onnx/${field.value?.model_name}`
+       ]) ??
+     null,
+   [
+     field.value?.base_model,
+     field.value?.model_name,
+     mainModels?.entities,
+     onnxModels?.entities,
+   ]
  );

  const handleChangeModel = useCallback(

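The two lookups above rely on RTK Query entity ids of the form `{base_model}/{model_type}/{model_name}`, so the same field value resolves against the `main` cache first and the `onnx` cache as a fallback. Hypothetical keys for illustration (the model names here are examples, not from the source):

// Illustrative only; these entity ids show the key format, not real cache contents.
const mainKey = 'sd-1/main/stable-diffusion-v1-5'; // mainModels?.entities[mainKey]
const onnxKey = 'sd-1/onnx/stable-diffusion-v1-5'; // onnxModels?.entities[onnxKey]
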
@@ -1,3 +1,5 @@
+ import { RootState } from 'app/store/store';
+ import { useAppSelector } from 'app/store/storeHooks';
import { useColorModeValue } from '@chakra-ui/react';
import { memo } from 'react';
import { MiniMap } from 'reactflow';
@@ -12,6 +14,10 @@ const MinimapPanel = () => {
    }
  );

+ const shouldShowMinimapPanel = useAppSelector(
+   (state: RootState) => state.nodes.shouldShowMinimapPanel
+ );

  const nodeColor = useColorModeValue(
    'var(--invokeai-colors-accent-300)',
    'var(--invokeai-colors-accent-700)'
@@ -23,15 +29,19 @@ const MinimapPanel = () => {
  );

  return (
-   <MiniMap
-     nodeStrokeWidth={3}
-     pannable
-     zoomable
-     nodeBorderRadius={30}
-     style={miniMapStyle}
-     nodeColor={nodeColor}
-     maskColor={maskColor}
-   />
+   <>
+     {shouldShowMinimapPanel && (
+       <MiniMap
+         nodeStrokeWidth={3}
+         pannable
+         zoomable
+         nodeBorderRadius={30}
+         style={miniMapStyle}
+         nodeColor={nodeColor}
+         maskColor={maskColor}
+       />
+     )}
+   </>
  );
};

@@ -9,10 +9,13 @@ const TopRightPanel = () => {
  const shouldShowGraphOverlay = useAppSelector(
    (state: RootState) => state.nodes.shouldShowGraphOverlay
  );
+ const shouldShowFieldTypeLegend = useAppSelector(
+   (state: RootState) => state.nodes.shouldShowFieldTypeLegend
+ );

  return (
    <Panel position="top-right">
-     <FieldTypeLegend />
+     {shouldShowFieldTypeLegend && <FieldTypeLegend />}
      {shouldShowGraphOverlay && <NodeGraphOverlay />}
    </Panel>
  );

@@ -32,6 +32,8 @@ export type NodesState = {
  invocationTemplates: Record<string, InvocationTemplate>;
  connectionStartParams: OnConnectStartParams | null;
  shouldShowGraphOverlay: boolean;
+ shouldShowFieldTypeLegend: boolean;
+ shouldShowMinimapPanel: boolean;
  editorInstance: ReactFlowInstance | undefined;
};

@@ -42,6 +44,8 @@ export const initialNodesState: NodesState = {
  invocationTemplates: {},
  connectionStartParams: null,
  shouldShowGraphOverlay: false,
+ shouldShowFieldTypeLegend: false,
+ shouldShowMinimapPanel: true,
  editorInstance: undefined,
};

@@ -125,6 +129,15 @@ const nodesSlice = createSlice({
    shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
      state.shouldShowGraphOverlay = action.payload;
    },
+   shouldShowFieldTypeLegendChanged: (
+     state,
+     action: PayloadAction<boolean>
+   ) => {
+     state.shouldShowFieldTypeLegend = action.payload;
+   },
+   shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
+     state.shouldShowMinimapPanel = action.payload;
+   },
    nodeTemplatesBuilt: (
      state,
      action: PayloadAction<Record<string, InvocationTemplate>>
@@ -161,6 +174,8 @@ export const {
  connectionStarted,
  connectionEnded,
  shouldShowGraphOverlayChanged,
+ shouldShowFieldTypeLegendChanged,
+ shouldShowMinimapPanelChanged,
  nodeTemplatesBuilt,
  nodeEditorReset,
  imageCollectionFieldValueChanged,

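These reducers follow the standard Redux Toolkit shape: the action payload carries the next boolean value, and callers toggle by dispatching the negation of the current flag. A minimal usage sketch (the component name and absolute import path are assumptions for illustration):

// Hypothetical component wiring for the new minimap flag.
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shouldShowMinimapPanelChanged } from 'features/nodes/store/nodesSlice';

const MinimapToggleExample = () => {
  const dispatch = useAppDispatch();
  const shouldShowMinimapPanel = useAppSelector(
    (state) => state.nodes.shouldShowMinimapPanel
  );

  // Dispatch the inverse of the current value to toggle the panel.
  const toggle = () =>
    dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));

  return <button onClick={toggle}>Toggle minimap</button>;
};
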
@@ -9,6 +9,7 @@ import {
  CLIP_SKIP,
  LORA_LOADER,
  MAIN_MODEL_LOADER,
+ ONNX_MODEL_LOADER,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  POSITIVE_CONDITIONING,
@@ -17,7 +18,8 @@ import {
export const addLoRAsToGraph = (
  state: RootState,
  graph: NonNullableGraph,
- baseNodeId: string
+ baseNodeId: string,
+ modelLoader: string = MAIN_MODEL_LOADER
): void => {
  /**
   * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
@@ -40,6 +42,10 @@ export const addLoRAsToGraph = (
      !(
        e.source.node_id === MAIN_MODEL_LOADER &&
        ['unet'].includes(e.source.field)
      ) &&
+     !(
+       e.source.node_id === ONNX_MODEL_LOADER &&
+       ['unet'].includes(e.source.field)
+     )
    );
    // Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
@@ -74,12 +80,11 @@ export const addLoRAsToGraph = (

  // add to graph
  graph.nodes[currentLoraNodeId] = loraLoaderNode;

  if (currentLoraIndex === 0) {
    // first lora = start the lora chain, attach directly to model loader
    graph.edges.push({
      source: {
-       node_id: MAIN_MODEL_LOADER,
+       node_id: modelLoader,
        field: 'unet',
      },
      destination: {

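With the new optional `modelLoader` parameter, existing call sites keep their behavior while ONNX graph builders can redirect the LoRA chain to the ONNX loader node. A hypothetical call site (`BASE_NODE_ID` is an illustrative placeholder, not from the source):

// Default call: LoRAs chain off the main model loader, as before.
addLoRAsToGraph(state, graph, BASE_NODE_ID);

// ONNX graph: chain the first LoRA off the ONNX model loader instead.
addLoRAsToGraph(state, graph, BASE_NODE_ID, ONNX_MODEL_LOADER);
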
Some files were not shown because too many files have changed in this diff.