Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-15 10:48:12 -05:00)

Compare commits: feat/batch ... v3.0.2 (201 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 06296896a9 | |
| | a7399aca0c | |
| | d1ea8b1e98 | |
| | f851ad7ba0 | |
| | 591838a84b | |
| | c0c2ab3dcf | |
| | 56023bc725 | |
| | 2ef6a8995b | |
| | d0fee93aac | |
| | 1bfe9835cf | |
| | 8e7eae6cc7 | |
| | f6522c8971 | |
| | a969707e45 | |
| | 6c8e898f09 | |
| | 7bad9bcf53 | |
| | d42b45116f | |
| | d4812bbc8d | |
| | 3cd05cf6bf | |
| | 2564301aeb | |
| | da0efeaa7f | |
| | 49cce1eec6 | |
| | c8fbaf54b6 | |
| | f86d388786 | |
| | cd2c688562 | |
| | 2d29ac6f0d | |
| | 2c2b731386 | |
| | 2f68a1a76c | |
| | 930e7bc754 | |
| | 7d4ace962a | |
| | 06842f8e0a | |
| | c82da330db | |
| | 628df4ec98 | |
| | 16b956616f | |
| | 604cc17a3a | |
| | 37c9b85549 | |
| | 8b39b67ec7 | |
| | a933977861 | |
| | dfb41d8461 | |
| | 4d5169e16d | |
| | f56f19710d | |
| | e77400ab62 | |
| | 13347f6aec | |
| | a9bf387e5e | |
| | 8258c87a9f | |
| | 1b1b399fd0 | |
| | a8d3e078c0 | |
| | 6ed7ba57dd | |
| | 2b3b77a276 | |
| | 8b8ec68b30 | |
| | e20af5aef0 | |
| | 57e8ec9488 | |
| | 734a9e4271 | |
| | fe924daee3 | |
| | 750f09fbed | |
| | 4df581811e | |
| | eb70bc2ae4 | |
| | 809705c30d | |
| | f0918edf98 | |
| | a846d82fa1 | |
| | 22f7cf0638 | |
| | 25c669b1d6 | |
| | 4367061b19 | |
| | 0fd13d3604 | |
| | 72a3e776b2 | |
| | af044007d5 | |
| | f272a44feb | |
| | 8469d3e95a | |
| | ae17d01e1d | |
| | f3d3316558 | |
| | 5a6cefb0ea | |
| | 1a6f5f0860 | |
| | 5bfd6cb66f | |
| | 59caff7ff0 | |
| | 6487e7d906 | |
| | 77033eabd3 | |
| | b80abdd101 | |
| | 006d782cc8 | |
| | d09dfc3e9b | |
| | 66f524cae7 | |
| | 9ba50130a1 | |
| | d4cf2d2666 | |
| | b8b589c150 | |
| | d93900a8de | |
| | 7f4c387080 | |
| | 80876bbbd1 | |
| | 7a4ff4c089 | |
| | 44bf308192 | |
| | 12e51c84ae | |
| | b2eb83deff | |
| | 0ccc3b509e | |
| | 4043a4c21c | |
| | c8ceb96091 | |
| | 83f75750a9 | |
| | dc96a3e79d | |
| | c076f1397e | |
| | 2568aafc0b | |
| | 65ed224bfc | |
| | b6e369c745 | |
| | ecabfc252b | |
| | da96a41103 | |
| | d162b78767 | |
| | eb6c317f04 | |
| | 6d7223238f | |
| | 8607d124c5 | |
| | 23497bf759 | |
| | b10cf20eb1 | |
| | 3d93851dba | |
| | 9bacd77a79 | |
| | 1b158f62c4 | |
| | 6ad565d84c | |
| | 04229082d6 | |
| | 03c27412f7 | |
| | f0613bb0ef | |
| | 0e9f92b868 | |
| | 7d0cc6ec3f | |
| | 2f8b928486 | |
| | 0d3c27f46c | |
| | cff91f06d3 | |
| | 1d5d187ba1 | |
| | 1ac14a1e43 | |
| | cfc3a20565 | |
| | 05ae4e283c | |
| | f06fee4581 | |
| | 9091e19de8 | |
| | 0a0b7141af | |
| | 1deca89fde | |
| | 446fb4a438 | |
| | ab5d938a1d | |
| | 9942af756a | |
| | 06742faca7 | |
| | d2bddf7f91 | |
| | 91ebf9f76e | |
| | bf94412d14 | |
| | e080fd1e08 | |
| | eeef1e08f8 | |
| | b3b94b5a8d | |
| | 5c9787c145 | |
| | cf72eba15c | |
| | a6f9396a30 | |
| | 118d5b387b | |
| | 02d2cc758d | |
| | db545f8801 | |
| | b0d72b15b3 | |
| | 4e0949fa55 | |
| | f028342f5b | |
| | 7021467048 | |
| | 26ef5249b1 | |
| | 87424be95d | |
| | 366952f810 | |
| | 450e95de59 | |
| | 0ba8a0ea6c | |
| | f4981f26d5 | |
| | 6bc21984c6 | |
| | 43d6312587 | |
| | 0d125bf3e4 | |
| | 921ccad04d | |
| | 05c9207e7b | |
| | 3fc789a7ee | |
| | 008362918e | |
| | 8fc75a71ee | |
| | 82d259f43b | |
| | ec48779080 | |
| | bc20fe4cb5 | |
| | 5de42be4a6 | |
| | 818c55cd53 | |
| | 0db1e97119 | |
| | 29ac252501 | |
| | 880727436c | |
| | 77c5c18542 | |
| | ed76250dba | |
| | 4d22cafdad | |
| | 1f9e984b0d | |
| | 8a4e5f73aa | |
| | 4599575e65 | |
| | 242d860a47 | |
| | 0c1a7e72d4 | |
| | 11a44b944d | |
| | fd7b842419 | |
| | 5998509888 | |
| | 7292d89108 | |
| | 437f45a97f | |
| | 13ef33ed64 | |
| | 86d8b46fca | |
| | df53b62048 | |
| | 55d3f04476 | |
| | 72ebe2ce68 | |
| | 7cd8b2f207 | |
| | bacdf985f1 | |
| | e3519052ae | |
| | adfd1e52f4 | |
| | 0e48c98330 | |
| | ff1c40747e | |
| | dbfd1bcb5e | |
| | ccceb32a85 | |
| | 21617e60e1 | |
| | 35dd58e273 | |
| | 86b8b69e88 | |
| | bc9a5038fd | |
| | b163ae6a4d | |
| | dca685ac25 | |
| | e70bedba7d | |
14  .github/workflows/style-checks.yml (vendored)
````diff
@@ -1,13 +1,14 @@
-name: Black # TODO: add isort and flake8 later
+name: style checks
+# just formatting for now
+# TODO: add isort and flake8 later
 
 on:
-  pull_request: {}
+  pull_request:
   push:
-    branches: master
-    tags: "*"
+    branches: main
 
 jobs:
-  test:
+  black:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
@@ -19,8 +20,7 @@ jobs:
 
       - name: Install dependencies with pip
        run: |
           pip install --upgrade pip wheel
-          pip install .[test]
+          pip install black
 
-      # - run: isort --check-only .
       - run: black --check .
````
50  .github/workflows/test-invoke-pip-skip.yml (vendored)
````diff
@@ -1,50 +0,0 @@
-name: Test invoke.py pip
-
-# This is a dummy stand-in for the actual tests
-# we don't need to run python tests on non-Python changes
-# But PRs require passing tests to be mergeable
-
-on:
-  pull_request:
-    paths:
-      - '**'
-      - '!pyproject.toml'
-      - '!invokeai/**'
-      - '!tests/**'
-      - 'invokeai/frontend/web/**'
-  merge_group:
-  workflow_dispatch:
-
-concurrency:
-  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
-  cancel-in-progress: true
-
-jobs:
-  matrix:
-    if: github.event.pull_request.draft == false
-    strategy:
-      matrix:
-        python-version:
-          - '3.10'
-        pytorch:
-          - linux-cuda-11_7
-          - linux-rocm-5_2
-          - linux-cpu
-          - macos-default
-          - windows-cpu
-        include:
-          - pytorch: linux-cuda-11_7
-            os: ubuntu-22.04
-          - pytorch: linux-rocm-5_2
-            os: ubuntu-22.04
-          - pytorch: linux-cpu
-            os: ubuntu-22.04
-          - pytorch: macos-default
-            os: macOS-12
-          - pytorch: windows-cpu
-            os: windows-2022
-    name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
-    runs-on: ${{ matrix.os }}
-    steps:
-      - name: skip
-        run: echo "no build required"
````
24  .github/workflows/test-invoke-pip.yml (vendored)
````diff
@@ -3,16 +3,7 @@ on:
   push:
     branches:
       - 'main'
-    paths:
-      - 'pyproject.toml'
-      - 'invokeai/**'
-      - '!invokeai/frontend/web/**'
   pull_request:
-    paths:
-      - 'pyproject.toml'
-      - 'invokeai/**'
-      - 'tests/**'
-      - '!invokeai/frontend/web/**'
     types:
       - 'ready_for_review'
       - 'opened'
@@ -65,10 +56,23 @@ jobs:
         id: checkout-sources
         uses: actions/checkout@v3
 
+      - name: Check for changed python files
+        id: changed-files
+        uses: tj-actions/changed-files@v37
+        with:
+          files_yaml: |
+            python:
+              - 'pyproject.toml'
+              - 'invokeai/**'
+              - '!invokeai/frontend/web/**'
+              - 'tests/**'
+
       - name: set test prompt to main branch validation
+        if: steps.changed-files.outputs.python_any_changed == 'true'
         run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
 
       - name: setup python
+        if: steps.changed-files.outputs.python_any_changed == 'true'
         uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
@@ -76,6 +80,7 @@ jobs:
           cache-dependency-path: pyproject.toml
 
       - name: install invokeai
+        if: steps.changed-files.outputs.python_any_changed == 'true'
         env:
           PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
         run: >
@@ -83,6 +88,7 @@ jobs:
           --editable=".[test]"
 
       - name: run pytest
+        if: steps.changed-files.outputs.python_any_changed == 'true'
         id: run-pytest
         run: pytest
````
38  README.md
````diff
@@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
 _For Windows/Linux with an NVIDIA GPU:_
 
 ```terminal
-pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
 ```
 
 _For Linux with an AMD GPU:_
@@ -184,8 +184,9 @@ the command `npm install -g yarn` if needed)
 6. Configure InvokeAI and install a starting set of image generation models (you only need to do this once):
 
    ```terminal
-   invokeai-configure
+   invokeai-configure --root .
    ```
+   Don't miss the dot at the end!
 
 7. Launch the web server (do it every time you run InvokeAI):
 
@@ -193,15 +194,9 @@ the command `npm install -g yarn` if needed)
    invokeai-web
    ```
 
-8. Build Node.js assets
+8. Point your browser to http://localhost:9090 to bring up the web interface.
 
-   ```terminal
-   cd invokeai/frontend/web/
-   yarn vite build
-   ```
-
-9. Point your browser to http://localhost:9090 to bring up the web interface.
-10. Type `banana sushi` in the box on the top left and click `Invoke`.
+9. Type `banana sushi` in the box on the top left and click `Invoke`.
 
 Be sure to activate the virtual environment each time before re-launching InvokeAI,
 using `source .venv/bin/activate` or `.venv\Scripts\activate`.
@@ -311,13 +306,30 @@ InvokeAI. The second will prepare the 2.3 directory for use with 3.0.
 You may now launch the WebUI in the usual way, by selecting option [1]
 from the launcher script
 
-#### Migration Caveats
+#### Migrating Images
 
 The migration script will migrate your invokeai settings and models,
 including textual inversion models, LoRAs and merges that you may have
 installed previously. However it does **not** migrate the generated
-images stored in your 2.3-format outputs directory. You will need to
-manually import selected images into the 3.0 gallery via drag-and-drop.
+images stored in your 2.3-format outputs directory. To do this, you
+need to run an additional step:
+
+1. From a working InvokeAI 3.0 root directory, start the launcher and
+   enter menu option [8] to open the "developer's console".
+
+2. At the developer's console command line, type the command:
+
+   ```bash
+   invokeai-import-images
+   ```
+
+3. This will lead you through the process of confirming the desired
+   source and destination for the imported images. The images will
+   appear in the gallery board of your choice, and contain the
+   original prompt, model name, and other parameters used to generate
+   the image.
+
+(Many kudos to **techjedi** for contributing this script.)
 
 ## Hardware Requirements
````
````diff
@@ -264,7 +264,7 @@ experimental versions later.
 you can create several levels of subfolders and drop your models into
 whichever ones you want.
 
-- ***Autoimport Folder***
+- ***LICENSE***
 
 At the bottom of the screen you will see a checkbox for accepting
 the CreativeML Responsible AI Licenses. You need to accept the license
@@ -471,7 +471,7 @@ Then type the following commands:
 
 === "NVIDIA System"
     ```bash
-    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu117
+    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
     pip install xformers
     ```
@@ -148,7 +148,7 @@ manager, please follow these steps:
 === "CUDA (NVidia)"
 
     ```bash
-    pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+    pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
     ```
 
 === "ROCm (AMD)"
@@ -192,8 +192,10 @@ manager, please follow these steps:
    your outputs.
 
    ```terminal
-   invokeai-configure
+   invokeai-configure --root .
    ```
+
+   Don't miss the dot at the end of the command!
 
    The script `invokeai-configure` will interactively guide you through the
    process of downloading and installing the weights files needed for InvokeAI.
@@ -225,12 +227,6 @@ manager, please follow these steps:
 
    !!! warning "Make sure that the virtual environment is activated, which should create `(.venv)` in front of your prompt!"
 
-    === "CLI"
-
-        ```bash
-        invokeai
-        ```
-
    === "local Webserver"
 
        ```bash
@@ -243,6 +239,12 @@ manager, please follow these steps:
        invokeai --web --host 0.0.0.0
        ```
 
+    === "CLI"
+
+        ```bash
+        invokeai
+        ```
+
    If you choose the run the web interface, point your browser at
    http://localhost:9090 in order to load the GUI.
@@ -310,7 +312,7 @@ installation protocol (important!)
 
 === "CUDA (NVidia)"
    ```bash
-   pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+   pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
   ```
 
 === "ROCm (AMD)"
@@ -354,7 +356,7 @@ you can do so using this unsupported recipe:
 mkdir ~/invokeai
 conda create -n invokeai python=3.10
 conda activate invokeai
-pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
 invokeai-configure --root ~/invokeai
 invokeai --root ~/invokeai --web
 ```
@@ -34,11 +34,11 @@ directly from NVIDIA. **Do not try to install Ubuntu's
 nvidia-cuda-toolkit package. It is out of date and will cause
 conflicts among the NVIDIA driver and binaries.**
 
-Go to [CUDA Toolkit 11.7
-Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive),
-and use the target selection wizard to choose your operating system,
-hardware platform, and preferred installation method (e.g. "local"
-versus "network").
+Go to [CUDA Toolkit
+Downloads](https://developer.nvidia.com/cuda-downloads), and use the
+target selection wizard to choose your operating system, hardware
+platform, and preferred installation method (e.g. "local" versus
+"network").
 
 This will provide you with a downloadable install file or, depending
 on your choices, a recipe for downloading and running a install shell
@@ -61,7 +61,7 @@ Runtime Site](https://developer.nvidia.com/nvidia-container-runtime)
 
 When installing torch and torchvision manually with `pip`, remember to provide
 the argument `--extra-index-url
-https://download.pytorch.org/whl/cu117` as described in the [Manual
+https://download.pytorch.org/whl/cu118` as described in the [Manual
 Installation Guide](020_INSTALL_MANUAL.md).
 
 ## :simple-amd: ROCm
@@ -124,7 +124,7 @@ installation. Examples:
 invokeai-model-install --list controlnet
 
 # (install the model at the indicated URL)
-invokeai-model-install --add http://civitai.com/2860
+invokeai-model-install --add https://civitai.com/api/download/models/128713
 
 # (delete the named model)
 invokeai-model-install --delete sd-1/main/analog-diffusion
@@ -170,4 +170,4 @@ elsewhere on disk and they will be autoimported. You can also create
 subfolders and organize them as you wish.
 
 The location of the autoimport directories are controlled by settings
-in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
+in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
@@ -28,18 +28,21 @@ command line, then just be sure to activate it's virtual environment.
 Then run the following three commands:
 
 ```sh
-pip install xformers==0.0.16rc425
-pip install triton
+pip install xformers~=0.0.19
+pip install triton # WON'T WORK ON WINDOWS
 python -m xformers.info output
 ```
 
 The first command installs `xformers`, the second installs the
 `triton` training accelerator, and the third prints out the `xformers`
-installation status. If all goes well, you'll see a report like the
+installation status. On Windows, please omit the `triton` package,
+which is not available on that platform.
+
+If all goes well, you'll see a report like the
 following:
 
 ```sh
-xFormers 0.0.16rc425
+xFormers 0.0.20
 memory_efficient_attention.cutlassF: available
 memory_efficient_attention.cutlassB: available
 memory_efficient_attention.flshattF: available
@@ -48,22 +51,28 @@ memory_efficient_attention.smallkF: available
 memory_efficient_attention.smallkB: available
 memory_efficient_attention.tritonflashattF: available
 memory_efficient_attention.tritonflashattB: available
+indexing.scaled_index_addF: available
+indexing.scaled_index_addB: available
+indexing.index_select: available
+swiglu.dual_gemm_silu: available
+swiglu.gemm_fused_operand_sum: available
+swiglu.fused.p.cpp: available
 is_triton_available: True
 is_functorch_available: False
-pytorch.version: 1.13.1+cu117
+pytorch.version: 2.0.1+cu118
 pytorch.cuda: available
-gpu.compute_capability: 8.6
-gpu.name: NVIDIA RTX A2000 12GB
+gpu.compute_capability: 8.9
+gpu.name: NVIDIA GeForce RTX 4070
 build.info: available
-build.cuda_version: 1107
-build.python_version: 3.10.9
-build.torch_version: 1.13.1+cu117
+build.cuda_version: 1108
+build.python_version: 3.10.11
+build.torch_version: 2.0.1+cu118
 build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
 build.env.XFORMERS_BUILD_TYPE: Release
 build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
 build.env.NVCC_FLAGS: None
-build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.16rc425
+build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.20
 build.nvcc_version: 11.8.89
 source.privacy: open source
 ```
@@ -83,14 +92,14 @@ installed from source. These instructions were written for a system
 running Ubuntu 22.04, but other Linux distributions should be able to
 adapt this recipe.
 
-#### 1. Install CUDA Toolkit 11.7
+#### 1. Install CUDA Toolkit 11.8
 
 You will need the CUDA developer's toolkit in order to compile and
 install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
 package.** It is out of date and will cause conflicts among the NVIDIA
 driver and binaries. Instead install the CUDA Toolkit package provided
-by NVIDIA itself. Go to [CUDA Toolkit 11.7
-Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive)
+by NVIDIA itself. Go to [CUDA Toolkit 11.8
+Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
 and use the target selection wizard to choose your platform and Linux
 distribution. Select an installer type of "runfile (local)" at the
 last step.
@@ -101,17 +110,17 @@ example, the install script recipe for Ubuntu 22.04 running on a
 x86_64 system is:
 
 ```
-wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
-sudo sh cuda_11.7.0_515.43.04_linux.run
+wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
+sudo sh cuda_11.8.0_520.61.05_linux.run
 ```
 
 Rather than cut-and-paste this example, We recommend that you walk
 through the toolkit wizard in order to get the most up to date
 installer for your system.
 
-#### 2. Confirm/Install pyTorch 1.13 with CUDA 11.7 support
+#### 2. Confirm/Install pyTorch 2.01 with CUDA 11.8 support
 
-If you are using InvokeAI 2.3 or higher, these will already be
+If you are using InvokeAI 3.0.2 or higher, these will already be
 installed. If not, you can check whether you have the needed libraries
 using a quick command. Activate the invokeai virtual environment,
 either by entering the "developer's console", or manually with a
@@ -124,7 +133,7 @@ Then run the command:
 python -c 'exec("import torch\nprint(torch.__version__)")'
 ```
 
-If it prints __1.13.1+cu117__ you're good. If not, you can install the
+If it prints __1.13.1+cu118__ you're good. If not, you can install the
 most up to date libraries with this command:
 
 ```sh
````
````diff
@@ -348,7 +348,7 @@ class InvokeAiInstance:
 
         introduction()
 
-        from invokeai.frontend.install import invokeai_configure
+        from invokeai.frontend.install.invokeai_configure import invokeai_configure
 
         # NOTE: currently the config script does its own arg parsing! this means the command-line switches
         # from the installer will also automatically propagate down to the config script.
@@ -463,10 +463,10 @@ def get_torch_source() -> (Union[str, None], str):
         url = "https://download.pytorch.org/whl/cpu"
 
     if device == "cuda":
-        url = "https://download.pytorch.org/whl/cu117"
+        url = "https://download.pytorch.org/whl/cu118"
         optional_modules = "[xformers,onnx-cuda]"
     if device == "cuda_and_dml":
-        url = "https://download.pytorch.org/whl/cu117"
+        url = "https://download.pytorch.org/whl/cu118"
         optional_modules = "[xformers,onnx-directml]"
 
     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
````
````diff
@@ -2,7 +2,6 @@
 
 from typing import Optional
 from logging import Logger
-import os
 from invokeai.app.services.board_image_record_storage import (
     SqliteBoardImageRecordStorage,
 )
@@ -30,6 +29,7 @@ from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
 from ..services.model_manager_service import ModelManagerService
+from ..services.invocation_stats import InvocationStatsService
 from .events import FastAPIEventService
 
 
@@ -55,7 +55,7 @@ logger = InvokeAILogger.getLogger()
 class ApiDependencies:
     """Contains and initializes all dependencies for the API"""
 
-    invoker: Optional[Invoker] = None
+    invoker: Invoker
 
     @staticmethod
     def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
@@ -68,8 +68,9 @@ class ApiDependencies:
         output_folder = config.output_path
 
         # TODO: build a file/path manager?
-        db_location = config.db_path
-        db_location.parent.mkdir(parents=True, exist_ok=True)
+        db_path = config.db_path
+        db_path.parent.mkdir(parents=True, exist_ok=True)
+        db_location = str(db_path)
 
         graph_execution_manager = SqliteItemStorage[GraphExecutionState](
             filename=db_location, table_name="graph_executions"
@@ -128,6 +129,7 @@ class ApiDependencies:
             graph_execution_manager=graph_execution_manager,
             processor=DefaultInvocationProcessor(),
             configuration=config,
+            performance_statistics=InvocationStatsService(graph_execution_manager),
             logger=logger,
         )
````
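The `db_path` change above creates the database's parent directory before handing SQLite a plain string path; `sqlite3` does not create missing directories on its own. A standalone sketch of the same pattern (the path below is illustrative, not InvokeAI's actual default):

```python
import sqlite3
from pathlib import Path

db_path = Path("outputs/databases/invokeai.db")    # illustrative location
db_path.parent.mkdir(parents=True, exist_ok=True)  # sqlite3 won't mkdir for you
conn = sqlite3.connect(str(db_path))               # pass a plain string, as the diff does
```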
````diff
@@ -1,24 +1,30 @@
-from fastapi import Body, HTTPException, Path, Query
+from fastapi import Body, HTTPException
 from fastapi.routing import APIRouter
-from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
-from invokeai.app.services.image_record_storage import OffsetPaginatedResults
-from invokeai.app.services.models.board_record import BoardDTO
-from invokeai.app.services.models.image_record import ImageDTO
+from pydantic import BaseModel, Field
 
 from ..dependencies import ApiDependencies
 
 board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
 
 
+class AddImagesToBoardResult(BaseModel):
+    board_id: str = Field(description="The id of the board the images were added to")
+    added_image_names: list[str] = Field(description="The image names that were added to the board")
+
+
+class RemoveImagesFromBoardResult(BaseModel):
+    removed_image_names: list[str] = Field(description="The image names that were removed from their board")
+
+
 @board_images_router.post(
     "/",
-    operation_id="create_board_image",
+    operation_id="add_image_to_board",
     responses={
         201: {"description": "The image was added to a board successfully"},
     },
     status_code=201,
 )
-async def create_board_image(
+async def add_image_to_board(
     board_id: str = Body(description="The id of the board to add to"),
     image_name: str = Body(description="The name of the image to add"),
 ):
@@ -29,26 +35,78 @@ async def create_board_image(
         )
         return result
     except Exception as e:
-        raise HTTPException(status_code=500, detail="Failed to add to board")
+        raise HTTPException(status_code=500, detail="Failed to add image to board")
 
 
 @board_images_router.delete(
     "/",
-    operation_id="remove_board_image",
+    operation_id="remove_image_from_board",
     responses={
         201: {"description": "The image was removed from the board successfully"},
     },
     status_code=201,
 )
-async def remove_board_image(
-    board_id: str = Body(description="The id of the board"),
-    image_name: str = Body(description="The name of the image to remove"),
+async def remove_image_from_board(
+    image_name: str = Body(description="The name of the image to remove", embed=True),
 ):
-    """Deletes a board_image"""
+    """Removes an image from its board, if it had one"""
     try:
-        result = ApiDependencies.invoker.services.board_images.remove_image_from_board(
-            board_id=board_id, image_name=image_name
-        )
+        result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
         return result
     except Exception as e:
-        raise HTTPException(status_code=500, detail="Failed to update board")
+        raise HTTPException(status_code=500, detail="Failed to remove image from board")
+
+
+@board_images_router.post(
+    "/batch",
+    operation_id="add_images_to_board",
+    responses={
+        201: {"description": "Images were added to board successfully"},
+    },
+    status_code=201,
+    response_model=AddImagesToBoardResult,
+)
+async def add_images_to_board(
+    board_id: str = Body(description="The id of the board to add to"),
+    image_names: list[str] = Body(description="The names of the images to add", embed=True),
+) -> AddImagesToBoardResult:
+    """Adds a list of images to a board"""
+    try:
+        added_image_names: list[str] = []
+        for image_name in image_names:
+            try:
+                ApiDependencies.invoker.services.board_images.add_image_to_board(
+                    board_id=board_id, image_name=image_name
+                )
+                added_image_names.append(image_name)
+            except:
+                pass
+        return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="Failed to add images to board")
+
+
+@board_images_router.post(
+    "/batch/delete",
+    operation_id="remove_images_from_board",
+    responses={
+        201: {"description": "Images were removed from board successfully"},
+    },
+    status_code=201,
+    response_model=RemoveImagesFromBoardResult,
+)
+async def remove_images_from_board(
+    image_names: list[str] = Body(description="The names of the images to remove", embed=True),
+) -> RemoveImagesFromBoardResult:
+    """Removes a list of images from their board, if they had one"""
+    try:
+        removed_image_names: list[str] = []
+        for image_name in image_names:
+            try:
+                ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
+                removed_image_names.append(image_name)
+            except:
+                pass
+        return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="Failed to remove images from board")
````
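With several `Body(...)` parameters, FastAPI embeds them all as top-level keys of a JSON object, so the new batch route takes a body of the shape `{"board_id": ..., "image_names": [...]}`. A rough client-side sketch follows; the base URL and the `/api` mount prefix are assumptions about a default local install, and the ids and names are placeholders:

```python
import requests

BASE = "http://localhost:9090/api/v1"  # assumed local default; adjust to your deployment

resp = requests.post(
    f"{BASE}/board_images/batch",
    json={
        "board_id": "some-board-id",        # placeholder board id
        "image_names": ["a.png", "b.png"],  # placeholder image names
    },
)
resp.raise_for_status()
print(resp.json()["added_image_names"])  # the subset that was actually added
```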
````diff
@@ -1,21 +1,20 @@
 import io
 from typing import Optional
 
-from PIL import Image
 from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
 from fastapi.responses import FileResponse
 from fastapi.routing import APIRouter
+from PIL import Image
+from pydantic import BaseModel
 
 from invokeai.app.invocations.metadata import ImageMetadata
 from invokeai.app.models.image import ImageCategory, ResourceOrigin
 from invokeai.app.services.image_record_storage import OffsetPaginatedResults
-from invokeai.app.services.item_storage import PaginatedResults
 from invokeai.app.services.models.image_record import (
     ImageDTO,
     ImageRecordChanges,
     ImageUrlsDTO,
 )
 
 from ..dependencies import ApiDependencies
 
 images_router = APIRouter(prefix="/v1/images", tags=["images"])
@@ -25,7 +24,7 @@ IMAGE_MAX_AGE = 31536000
 
 
 @images_router.post(
-    "/",
+    "/upload",
     operation_id="upload_image",
     responses={
         201: {"description": "The image was uploaded successfully"},
@@ -77,7 +76,7 @@ async def upload_image(
     raise HTTPException(status_code=500, detail="Failed to create image")
 
 
-@images_router.delete("/{image_name}", operation_id="delete_image")
+@images_router.delete("/i/{image_name}", operation_id="delete_image")
 async def delete_image(
     image_name: str = Path(description="The name of the image to delete"),
 ) -> None:
@@ -103,7 +102,7 @@ async def clear_intermediates() -> int:
 
 
 @images_router.patch(
-    "/{image_name}",
+    "/i/{image_name}",
     operation_id="update_image",
     response_model=ImageDTO,
 )
@@ -120,7 +119,7 @@ async def update_image(
 
 
 @images_router.get(
-    "/{image_name}",
+    "/i/{image_name}",
     operation_id="get_image_dto",
     response_model=ImageDTO,
 )
@@ -136,7 +135,7 @@ async def get_image_dto(
 
 
 @images_router.get(
-    "/{image_name}/metadata",
+    "/i/{image_name}/metadata",
     operation_id="get_image_metadata",
     response_model=ImageMetadata,
 )
@@ -151,8 +150,9 @@ async def get_image_metadata(
     raise HTTPException(status_code=404)
 
 
-@images_router.get(
-    "/{image_name}/full",
+@images_router.api_route(
+    "/i/{image_name}/full",
+    methods=["GET", "HEAD"],
     operation_id="get_image_full",
     response_class=Response,
     responses={
@@ -187,7 +187,7 @@ async def get_image_full(
 
 
 @images_router.get(
-    "/{image_name}/thumbnail",
+    "/i/{image_name}/thumbnail",
     operation_id="get_image_thumbnail",
     response_class=Response,
     responses={
@@ -216,7 +216,7 @@ async def get_image_thumbnail(
 
 
 @images_router.get(
-    "/{image_name}/urls",
+    "/i/{image_name}/urls",
     operation_id="get_image_urls",
     response_model=ImageUrlsDTO,
 )
@@ -265,3 +265,24 @@ async def list_image_dtos(
     )
 
     return image_dtos
+
+
+class DeleteImagesFromListResult(BaseModel):
+    deleted_images: list[str]
+
+
+@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
+async def delete_images_from_list(
+    image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
+) -> DeleteImagesFromListResult:
+    try:
+        deleted_images: list[str] = []
+        for image_name in image_names:
+            try:
+                ApiDependencies.invoker.services.images.delete(image_name)
+                deleted_images.append(image_name)
+            except:
+                pass
+        return DeleteImagesFromListResult(deleted_images=deleted_images)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="Failed to delete images")
````
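The new bulk-delete route uses the same embedded-body convention: `embed=True` wraps the list under an `image_names` key rather than accepting a bare JSON array. A hedged client sketch (base URL and `/api` prefix assumed as above, names are placeholders):

```python
import requests

resp = requests.post(
    "http://localhost:9090/api/v1/images/delete",  # assumed local default
    json={"image_names": ["a.png", "b.png"]},      # embed=True wraps the list in a key
)
resp.raise_for_status()
print(resp.json()["deleted_images"])  # names that were successfully deleted
```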
````diff
@@ -104,8 +104,12 @@ async def update_model(
         ):  # model manager moved model path during rename - don't overwrite it
             info.path = new_info.get("path")
 
+        # replace empty string values with None/null to avoid phenomenon of vae: ''
+        info_dict = info.dict()
+        info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
+
         ApiDependencies.invoker.services.model_manager.update_model(
-            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
+            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
         )
 
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
````
````diff
@@ -37,6 +37,7 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
 from invokeai.app.services.images import ImageService, ImageServiceDependencies
 from invokeai.app.services.resource_name import SimpleNameService
 from invokeai.app.services.urls import LocalUrlService
+from invokeai.app.services.invocation_stats import InvocationStatsService
 from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
 
@@ -311,6 +312,7 @@ def invoke_cli():
         graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
         graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
+        performance_statistics=InvocationStatsService(graph_execution_manager),
         logger=logger,
         configuration=config,
     )
````
````diff
@@ -109,12 +109,15 @@ class CompelInvocation(BaseInvocation):
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=self.clip.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                            context=context,
-                        ).context.model
+                        (
+                            name,
+                            context.services.model_manager.get_model(
+                                model_name=name,
+                                base_model=self.clip.text_encoder.base_model,
+                                model_type=ModelType.TextualInversion,
+                                context=context,
+                            ).context.model,
+                        )
                     )
                 except ModelNotFoundException:
                     # print(e)
@@ -173,7 +176,7 @@ class CompelInvocation(BaseInvocation):
 
 
 class SDXLPromptInvocationBase:
-    def run_clip_raw(self, context, clip_field, prompt, get_pooled):
+    def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
         tokenizer_info = context.services.model_manager.get_model(
             **clip_field.tokenizer.dict(),
             context=context,
@@ -197,12 +200,15 @@ class SDXLPromptInvocationBase:
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=clip_field.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                            context=context,
-                        ).context.model
+                        (
+                            name,
+                            context.services.model_manager.get_model(
+                                model_name=name,
+                                base_model=clip_field.text_encoder.base_model,
+                                model_type=ModelType.TextualInversion,
+                                context=context,
+                            ).context.model,
+                        )
                     )
                 except ModelNotFoundException:
                     # print(e)
@@ -210,8 +216,8 @@ class SDXLPromptInvocationBase:
                 # print(traceback.format_exc())
                 print(f'Warn: trigger: "{trigger}" not found')
 
-        with ModelPatcher.apply_lora_text_encoder(
-            text_encoder_info.context.model, _lora_loader()
+        with ModelPatcher.apply_lora(
+            text_encoder_info.context.model, _lora_loader(), lora_prefix
         ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
             tokenizer,
             ti_manager,
@@ -247,7 +253,7 @@ class SDXLPromptInvocationBase:
 
         return c, c_pooled, None
 
-    def run_clip_compel(self, context, clip_field, prompt, get_pooled):
+    def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
         tokenizer_info = context.services.model_manager.get_model(
             **clip_field.tokenizer.dict(),
             context=context,
@@ -271,12 +277,15 @@ class SDXLPromptInvocationBase:
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=clip_field.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                            context=context,
-                        ).context.model
+                        (
+                            name,
+                            context.services.model_manager.get_model(
+                                model_name=name,
+                                base_model=clip_field.text_encoder.base_model,
+                                model_type=ModelType.TextualInversion,
+                                context=context,
+                            ).context.model,
+                        )
                     )
                 except ModelNotFoundException:
                     # print(e)
@@ -284,8 +293,8 @@ class SDXLPromptInvocationBase:
                 # print(traceback.format_exc())
                 print(f'Warn: trigger: "{trigger}" not found')
 
-        with ModelPatcher.apply_lora_text_encoder(
-            text_encoder_info.context.model, _lora_loader()
+        with ModelPatcher.apply_lora(
+            text_encoder_info.context.model, _lora_loader(), lora_prefix
         ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
             tokenizer,
             ti_manager,
@@ -357,11 +366,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
+        c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
         if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
+            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
         else:
-            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
+            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")
 
         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -415,7 +424,8 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
+        # TODO: if there will appear lora for refiner - write proper prefix
+        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")
 
         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -467,11 +477,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
+        c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
         if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
+            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
         else:
-            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
+            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")
 
         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -525,7 +535,8 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
+        # TODO: if there will appear lora for refiner - write proper prefix
+        c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")
 
         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
````
````diff
@@ -1,26 +1,23 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
-from contextlib import contextmanager, ContextDecorator
 from functools import partial
 from typing import Literal, Optional, get_args
 
 import torch
 from pydantic import Field
 
 from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.backend.generator.inpaint import infill_methods
 
-from ...backend.generator import Inpaint, InvokeAIGenerator
-from ...backend.stable_diffusion import PipelineIntermediateState
-from ..util.step_callback import stable_diffusion_step_callback
 from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
-from .image import ImageOutput
-
-from ...backend.model_management.lora import ModelPatcher
-from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
-from .model import UNetField, VaeField
 from .compel import ConditioningField
+from contextlib import contextmanager, ExitStack, ContextDecorator
+from .image import ImageOutput
+from .model import UNetField, VaeField
+from ..util.step_callback import stable_diffusion_step_callback
+from ...backend.generator import Inpaint, InvokeAIGenerator
+from ...backend.model_management.lora import ModelPatcher
+from ...backend.stable_diffusion import PipelineIntermediateState
+from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
 
 SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
 INFILL_METHODS = Literal[tuple(infill_methods())]
@@ -184,6 +181,8 @@ class InpaintInvocation(BaseInvocation):
         device = context.services.model_manager.mgr.cache.execution_device
         dtype = context.services.model_manager.mgr.cache.precision
 
+        vae.to(dtype=unet.dtype)
+
         pipeline = StableDiffusionGeneratorPipeline(
             vae=vae,
             text_encoder=None,
@@ -193,8 +192,6 @@ class InpaintInvocation(BaseInvocation):
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
-            precision="float16" if dtype == torch.float16 else "float32",
-            execution_device=device,
         )
 
         yield OldModelInfo(
````
````diff
@@ -3,6 +3,7 @@
 from typing import Literal, Optional
 
 import numpy
+import cv2
 from PIL import Image, ImageFilter, ImageOps, ImageChops
 from pydantic import Field
 from pathlib import Path
@@ -500,7 +501,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
         image = context.services.images.get_pil_image(self.image.image_name)
 
         image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
-        image_arr = image_arr * (self.max - self.min) + self.max
+        image_arr = image_arr * (self.max - self.min) + self.min
 
         lerp_image = Image.fromarray(numpy.uint8(image_arr))
````
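The one-character fix above corrects the linear-interpolation formula: scaling a normalized 0..1 array by `(max - min)` and then adding `min` maps it onto `[min, max]`, whereas the old `+ self.max` offset pushed every pixel past the intended range. A standalone check of the corrected mapping (values are illustrative):

```python
import numpy as np

# Map a normalized 0..1 array onto [min_val, max_val]: the corrected formula.
min_val, max_val = 64.0, 192.0
arr = np.array([0.0, 0.5, 1.0], dtype=np.float32)
print(arr * (max_val - min_val) + min_val)  # [ 64. 128. 192.] (in range)
print(arr * (max_val - min_val) + max_val)  # [192. 256. 320.] (the old bug overflows)
```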
````diff
@@ -650,3 +651,143 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
             width=image_dto.width,
             height=image_dto.height,
         )
+
+
+class ImageHueAdjustmentInvocation(BaseInvocation):
+    """Adjusts the Hue of an image."""
+
+    # fmt: off
+    type: Literal["img_hue_adjust"] = "img_hue_adjust"
+
+    # Inputs
+    image: ImageField = Field(default=None, description="The image to adjust")
+    hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
+    # fmt: on
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        pil_image = context.services.images.get_pil_image(self.image.image_name)
+
+        # Convert image to HSV color space
+        hsv_image = numpy.array(pil_image.convert("HSV"))
+
+        # Convert hue from 0..360 to 0..256
+        hue = int(256 * ((self.hue % 360) / 360))
+
+        # Increment each hue and wrap around at 255
+        hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue) % 256
+
+        # Convert back to PIL format and to original color mode
+        pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA")
+
+        image_dto = context.services.images.create(
+            image=pil_image,
+            image_origin=ResourceOrigin.INTERNAL,
+            image_category=ImageCategory.GENERAL,
+            node_id=self.id,
+            is_intermediate=self.is_intermediate,
+            session_id=context.graph_execution_state_id,
+        )
+
+        return ImageOutput(
+            image=ImageField(
+                image_name=image_dto.image_name,
+            ),
+            width=image_dto.width,
+            height=image_dto.height,
+        )
+
+
+class ImageLuminosityAdjustmentInvocation(BaseInvocation):
+    """Adjusts the Luminosity (Value) of an image."""
+
+    # fmt: off
+    type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
+
+    # Inputs
+    image: ImageField = Field(default=None, description="The image to adjust")
+    luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)")
+    # fmt: on
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        pil_image = context.services.images.get_pil_image(self.image.image_name)
+
+        # Convert PIL image to OpenCV format (numpy array), note color channel
+        # ordering is changed from RGB to BGR
+        image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
+
+        # Convert image to HSV color space
+        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+        # Adjust the luminosity (value)
+        hsv_image[:, :, 2] = numpy.clip(hsv_image[:, :, 2] * self.luminosity, 0, 255)
+
+        # Convert image back to BGR color space
+        image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
+
+        # Convert back to PIL format and to original color mode
+        pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
+
+        image_dto = context.services.images.create(
+            image=pil_image,
+            image_origin=ResourceOrigin.INTERNAL,
+            image_category=ImageCategory.GENERAL,
+            node_id=self.id,
+            is_intermediate=self.is_intermediate,
+            session_id=context.graph_execution_state_id,
+        )
+
+        return ImageOutput(
+            image=ImageField(
+                image_name=image_dto.image_name,
+            ),
+            width=image_dto.width,
+            height=image_dto.height,
+        )
+
+
+class ImageSaturationAdjustmentInvocation(BaseInvocation):
+    """Adjusts the Saturation of an image."""
+
+    # fmt: off
+    type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
+
+    # Inputs
+    image: ImageField = Field(default=None, description="The image to adjust")
+    saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
+    # fmt: on
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        pil_image = context.services.images.get_pil_image(self.image.image_name)
+
+        # Convert PIL image to OpenCV format (numpy array), note color channel
+        # ordering is changed from RGB to BGR
+        image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
+
+        # Convert image to HSV color space
+        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+        # Adjust the saturation
+        hsv_image[:, :, 1] = numpy.clip(hsv_image[:, :, 1] * self.saturation, 0, 255)
+
+        # Convert image back to BGR color space
+        image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
+
+        # Convert back to PIL format and to original color mode
+        pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
+
+        image_dto = context.services.images.create(
+            image=pil_image,
+            image_origin=ResourceOrigin.INTERNAL,
+            image_category=ImageCategory.GENERAL,
+            node_id=self.id,
+            is_intermediate=self.is_intermediate,
+            session_id=context.graph_execution_state_id,
+        )
+
+        return ImageOutput(
+            image=ImageField(
+                image_name=image_dto.image_name,
+            ),
+            width=image_dto.width,
+            height=image_dto.height,
+        )
````
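The hue invocation above operates in PIL's 8-bit HSV representation, so a rotation given in degrees is rescaled to the 0..255 hue channel and wrapped with modular arithmetic. A minimal sketch of the same arithmetic outside the invocation framework (the function name and final RGB conversion are illustrative, not InvokeAI code):

```python
import numpy as np
from PIL import Image

def rotate_hue(img: Image.Image, degrees: int) -> Image.Image:
    # Same math as ImageHueAdjustmentInvocation: rescale degrees to the
    # 8-bit hue channel, shift, and wrap with modulo 256.
    hsv = np.array(img.convert("HSV"))
    shift = int(256 * ((degrees % 360) / 360))
    hsv[:, :, 0] = (hsv[:, :, 0] + shift) % 256
    return Image.fromarray(hsv, mode="HSV").convert("RGB")

# e.g. rotate_hue(Image.open("photo.png"), 90) shifts all hues by a quarter turn
```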
````diff
@@ -5,16 +5,27 @@ from typing import List, Literal, Optional, Union
 
 import einops
 import torch
 from diffusers import ControlNetModel
 from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from diffusers.schedulers import SchedulerMixin as Scheduler
 from pydantic import BaseModel, Field, validator
 
 from invokeai.app.invocations.metadata import CoreMetadata
+from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
+from invokeai.backend.model_management.models import ModelType, SilenceWarnings
 
-from ...backend.model_management.lora import ModelPatcher
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
+from .compel import ConditioningField
+from .controlnet_image_processors import ControlField
+from .image import ImageOutput
+from .model import ModelInfo, UNetField, VaeField
+from ..models.image import ImageCategory, ImageField, ResourceOrigin
+from ...backend.model_management import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
     ConditioningData,
@@ -24,23 +35,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 )
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.model_management import ModelPatcher
 from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
-from ..models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
-from .compel import ConditioningField
-from .controlnet_image_processors import ControlField
-from .image import ImageOutput
-from .model import ModelInfo, UNetField, VaeField
-from invokeai.app.util.controlnet_utils import prepare_control_image
-
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
-
 
 DEFAULT_PRECISION = choose_precision(choose_torch_device())
 
@@ -231,7 +226,6 @@ class TextToLatentsInvocation(BaseInvocation):
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
-            precision="float16" if unet.dtype == torch.float16 else "float32",
         )
 
     def prep_control_data(
````
````diff
@@ -1,7 +1,8 @@
 from typing import Literal, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
 
+from ...version import __version__
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -10,18 +11,20 @@ from invokeai.app.invocations.baseinvocation import (
 )
 from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
+from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
 
 
-class LoRAMetadataField(BaseModel):
+class LoRAMetadataField(BaseModelExcludeNull):
     """LoRA metadata for an image generated in InvokeAI."""
 
     lora: LoRAModelField = Field(description="The LoRA model")
     weight: float = Field(description="The weight of the LoRA model")
 
 
-class CoreMetadata(BaseModel):
+class CoreMetadata(BaseModelExcludeNull):
     """Core generation metadata for an image generated in InvokeAI."""
 
+    app_version: str = Field(default=__version__, description="The version of InvokeAI used to generate this image")
     generation_mode: str = Field(
         description="The generation mode that output this image",
     )
@@ -70,7 +73,7 @@ class CoreMetadata(BaseModelExcludeNull):
     refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
 
 
-class ImageMetadata(BaseModel):
+class ImageMetadata(BaseModelExcludeNull):
     """An image's generation metadata"""
 
     metadata: Optional[dict] = Field(
````
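The metadata models above swap their base class for `BaseModelExcludeNull`, whose name suggests a pydantic model that drops null fields at serialization time, consistent with the empty-string-to-None cleanup in the model-update route earlier in this changeset. A plausible sketch of such a helper, offered as an assumption rather than the actual `invokeai.app.util.model_exclude_null` implementation:

```python
from pydantic import BaseModel


class BaseModelExcludeNull(BaseModel):
    """Hypothetical sketch: a BaseModel whose dict() omits None-valued fields."""

    def dict(self, **kwargs):
        # Forcing exclude_none keeps serialized metadata free of null entries
        kwargs["exclude_none"] = True
        return super().dict(**kwargs)
```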
````diff
@@ -262,6 +262,103 @@ class LoraLoaderInvocation(BaseInvocation):
         return output
 
 
+class SDXLLoraLoaderOutput(BaseInvocationOutput):
+    """Model loader output"""
+
+    # fmt: off
+    type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
+
+    unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
+    clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
+    clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
+    # fmt: on
+
+
+class SDXLLoraLoaderInvocation(BaseInvocation):
+    """Apply selected lora to unet and text_encoder."""
+
+    type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
+
+    lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
+    weight: float = Field(default=0.75, description="With what weight to apply lora")
+
+    unet: Optional[UNetField] = Field(description="UNet model for applying lora")
+    clip: Optional[ClipField] = Field(description="Clip model for applying lora")
+    clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
+
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "SDXL Lora Loader",
+                "tags": ["lora", "loader"],
+                "type_hints": {"lora": "lora_model"},
+            },
+        }
+
+    def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
+        if self.lora is None:
+            raise Exception("No LoRA provided")
+
+        base_model = self.lora.base_model
+        lora_name = self.lora.model_name
+
+        if not context.services.model_manager.model_exists(
+            base_model=base_model,
+            model_name=lora_name,
+            model_type=ModelType.Lora,
+        ):
+            raise Exception(f"Unknown lora name: {lora_name}!")
+
+        if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
+            raise Exception(f'Lora "{lora_name}" already applied to unet')
+
+        if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
+            raise Exception(f'Lora "{lora_name}" already applied to clip')
+
+        if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
+            raise Exception(f'Lora "{lora_name}" already applied to clip2')
+
+        output = SDXLLoraLoaderOutput()
+
+        if self.unet is not None:
+            output.unet = copy.deepcopy(self.unet)
+            output.unet.loras.append(
+                LoraInfo(
+                    base_model=base_model,
+                    model_name=lora_name,
+                    model_type=ModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        if self.clip is not None:
+            output.clip = copy.deepcopy(self.clip)
+            output.clip.loras.append(
+                LoraInfo(
+                    base_model=base_model,
+                    model_name=lora_name,
+                    model_type=ModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        if self.clip2 is not None:
+            output.clip2 = copy.deepcopy(self.clip2)
+            output.clip2.loras.append(
+                LoraInfo(
+                    base_model=base_model,
+                    model_name=lora_name,
+                    model_type=ModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        return output
+
+
 class VAEModelField(BaseModel):
     """Vae model field"""
````
@@ -65,7 +65,6 @@ class ONNXPromptInvocation(BaseInvocation):
             **self.clip.text_encoder.dict(),
         )
         with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
-            # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
             loras = [
                 (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
                 for lora in self.clip.loras
@@ -76,18 +75,14 @@ class ONNXPromptInvocation(BaseInvocation):
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        # stack.enter_context(
-                        #     context.services.model_manager.get_model(
-                        #         model_name=name,
-                        #         base_model=self.clip.text_encoder.base_model,
-                        #         model_type=ModelType.TextualInversion,
-                        #     )
-                        # )
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=self.clip.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                        ).context.model
+                        (
+                            name,
+                            context.services.model_manager.get_model(
+                                model_name=name,
+                                base_model=self.clip.text_encoder.base_model,
+                                model_type=ModelType.TextualInversion,
+                            ).context.model,
+                        )
                     )
                 except Exception:
                     # print(e)
@@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union

 from pydantic import Field, validator

-from ...backend.model_management import ModelType, SubModelType
+from ...backend.model_management import ModelType, SubModelType, ModelPatcher
 from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext

@@ -293,10 +293,20 @@ class SDXLTextToLatentsInvocation(BaseInvocation):

         num_inference_steps = self.steps

+        def _lora_loader():
+            for lora in self.unet.loras:
+                lora_info = context.services.model_manager.get_model(
+                    **lora.dict(exclude={"weight"}),
+                    context=context,
+                )
+                yield (lora_info.context.model, lora.weight)
+                del lora_info
+            return
+
         unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
         do_classifier_free_guidance = True
         cross_attention_kwargs = None
-        with unet_info as unet:
+        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
             scheduler.set_timesteps(num_inference_steps, device=unet.device)
             timesteps = scheduler.timesteps

@@ -543,9 +553,19 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
             context=context,
         )

+        def _lora_loader():
+            for lora in self.unet.loras:
+                lora_info = context.services.model_manager.get_model(
+                    **lora.dict(exclude={"weight"}),
+                    context=context,
+                )
+                yield (lora_info.context.model, lora.weight)
+                del lora_info
+            return
+
         do_classifier_free_guidance = True
         cross_attention_kwargs = None
-        with unet_info as unet:
+        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
             # apply denoising_start
             num_inference_steps = self.steps
             scheduler.set_timesteps(num_inference_steps, device=unet.device)
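Note: the `_lora_loader` generator plus `ModelPatcher.apply_lora_unet` pattern above defers loading each LoRA until the UNet is actually being patched, and restores the original weights when the context exits. A minimal, self-contained sketch of the same idea, using toy classes rather than the InvokeAI API:

    from contextlib import contextmanager

    class ToyModel:
        def __init__(self):
            self.weight = 1.0

    @contextmanager
    def apply_loras(model, lora_weights):
        # lora_weights is any iterable of (delta, scale) pairs; a generator keeps
        # each delta out of memory until it is actually applied.
        original = model.weight
        try:
            for delta, scale in lora_weights:
                model.weight += delta * scale
            yield model
        finally:
            model.weight = original  # restore the unpatched weights on exit

    def _loader():
        for delta, scale in [(0.5, 0.75), (0.25, 1.0)]:
            yield (delta, scale)

    model = ToyModel()
    with apply_loras(model, _loader()):
        assert model.weight == 1.0 + 0.5 * 0.75 + 0.25 * 1.0
    assert model.weight == 1.0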
@@ -25,7 +25,6 @@ class BoardImageRecordStorageBase(ABC):
     @abstractmethod
     def remove_image_from_board(
         self,
-        board_id: str,
         image_name: str,
     ) -> None:
         """Removes an image from a board."""
@@ -154,7 +153,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):

     def remove_image_from_board(
         self,
-        board_id: str,
         image_name: str,
     ) -> None:
         try:
@@ -162,9 +160,9 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
             self._cursor.execute(
                 """--sql
                 DELETE FROM board_images
-                WHERE board_id = ? AND image_name = ?;
+                WHERE image_name = ?;
                 """,
-                (board_id, image_name),
+                (image_name,),
             )
             self._conn.commit()
         except sqlite3.Error as e:
@@ -31,7 +31,6 @@ class BoardImagesServiceABC(ABC):
     @abstractmethod
     def remove_image_from_board(
         self,
-        board_id: str,
         image_name: str,
     ) -> None:
         """Removes an image from a board."""
@@ -93,10 +92,9 @@ class BoardImagesService(BoardImagesServiceABC):

     def remove_image_from_board(
         self,
-        board_id: str,
         image_name: str,
     ) -> None:
-        self._services.board_image_records.remove_image_from_board(board_id, image_name)
+        self._services.board_image_records.remove_image_from_board(image_name)

     def get_all_board_image_names_for_board(
         self,
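Note: the DELETE above now keys on image_name alone, which suggests an image belongs to at most one board row in board_images. A minimal sqlite3 sketch of the same parameterized delete, with a hypothetical in-memory table rather than the app's real schema:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE board_images (board_id TEXT, image_name TEXT)")
    conn.execute("INSERT INTO board_images VALUES ('b1', 'img-001.png')")

    # placeholder binding matches the diff's (image_name,) tuple and avoids SQL injection
    conn.execute("DELETE FROM board_images WHERE image_name = ?;", ("img-001.png",))
    conn.commit()
    assert conn.execute("SELECT COUNT(*) FROM board_images").fetchone()[0] == 0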
@@ -24,11 +24,10 @@ InvokeAI:
     sequential_guidance: false
     precision: float16
     max_cache_size: 6
-    max_vram_cache_size: 2.7
+    max_vram_cache_size: 0.5
     always_use_cpu: false
     free_gpu_mem: false
   Features:
-    restore: true
     esrgan: true
     patchmatch: true
     internet_available: true
@@ -165,7 +164,7 @@ import pydoc
 import os
 import sys
 from argparse import ArgumentParser
-from omegaconf import OmegaConf, DictConfig
+from omegaconf import OmegaConf, DictConfig, ListConfig
 from pathlib import Path
 from pydantic import BaseSettings, Field, parse_obj_as
 from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
@@ -173,6 +172,7 @@ from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
 INIT_FILE = Path("invokeai.yaml")
 DB_FILE = Path("invokeai.db")
 LEGACY_INIT_FILE = Path("invokeai.init")
+DEFAULT_MAX_VRAM = 0.5


 class InvokeAISettings(BaseSettings):
@@ -189,7 +189,12 @@ class InvokeAISettings(BaseSettings):
         opt = parser.parse_args(argv)
         for name in self.__fields__:
             if name not in self._excluded():
-                setattr(self, name, getattr(opt, name))
+                value = getattr(opt, name)
+                if isinstance(value, ListConfig):
+                    value = list(value)
+                elif isinstance(value, DictConfig):
+                    value = dict(value)
+                setattr(self, name, value)

     def to_yaml(self) -> str:
         """
@@ -274,7 +279,7 @@ class InvokeAISettings(BaseSettings):
     @classmethod
     def _excluded(self) -> List[str]:
         # internal fields that shouldn't be exposed as command line options
-        return ["type", "initconf", "cached_root"]
+        return ["type", "initconf"]

     @classmethod
     def _excluded_from_yaml(self) -> List[str]:
@@ -282,15 +287,10 @@ class InvokeAISettings(BaseSettings):
         return [
             "type",
             "initconf",
-            "gpu_mem_reserved",
-            "max_loaded_models",
             "version",
             "from_file",
             "model",
-            "restore",
             "root",
-            "nsfw_checker",
-            "cached_root",
         ]

     class Config:
@@ -356,7 +356,7 @@ class InvokeAISettings(BaseSettings):
 def _find_root() -> Path:
     venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
     if os.environ.get("INVOKEAI_ROOT"):
-        root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
+        root = Path(os.environ["INVOKEAI_ROOT"])
     elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
         root = (venv.parent).resolve()
     else:
@@ -389,21 +389,17 @@ class InvokeAIAppConfig(InvokeAISettings):
     internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
     log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
     patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
-    restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')

     always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
     free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
-    max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='DEPRECATED')
     max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
     max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
-    gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
-    nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
     precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto', description='Floating point precision', category='Memory/Performance')
     sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
     xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
     tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')

-    root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
+    root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
     autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
     lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
     embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
@@ -415,8 +411,7 @@ class InvokeAIAppConfig(InvokeAISettings):
     outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
     from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
     use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')

     model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
+    ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')

     log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
     # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
@@ -424,9 +419,11 @@ class InvokeAIAppConfig(InvokeAISettings):
     log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")

     version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
-    cached_root : Path = Field(default=None, description="internal use only", category="DEPRECATED")
     # fmt: on

+    class Config:
+        validate_assignment = True
+
     def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
         """
         Update settings with contents of init file, environment, and
@@ -472,15 +469,12 @@ class InvokeAIAppConfig(InvokeAISettings):
         """
         Path to the runtime root directory
         """
-        # we cache value of root to protect against it being '.' and the cwd changing
-        if self.cached_root:
-            root = self.cached_root
-        elif self.root:
+        if self.root:
             root = Path(self.root).expanduser().absolute()
         else:
-            root = self.find_root()
-            self.cached_root = root
-        return self.cached_root
+            root = self.find_root().expanduser().absolute()
+            self.root = root  # insulate ourselves from relative paths that may change
+        return root

     @property
     def root_dir(self) -> Path:
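Note: the ListConfig/DictConfig coercion above matters because values read from invokeai.yaml arrive as OmegaConf container types, which can trip pydantic validation on assignment. A small sketch of the coercion in isolation:

    from omegaconf import OmegaConf, ListConfig, DictConfig

    conf = OmegaConf.create({"log_handlers": ["console", "file=/tmp/log"]})
    value = conf.log_handlers
    assert isinstance(value, ListConfig)

    if isinstance(value, ListConfig):
        value = list(value)  # plain list that pydantic will accept
    elif isinstance(value, DictConfig):
        value = dict(value)
    assert value == ["console", "file=/tmp/log"]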
@@ -289,9 +289,10 @@ class ImageService(ImageServiceABC):
     def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
         try:
             image_record = self._services.image_records.get(image_name)
+            metadata = self._services.image_records.get_metadata(image_name)

             if not image_record.session_id:
-                return ImageMetadata()
+                return ImageMetadata(metadata=metadata)

             session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
             graph = None
@@ -303,7 +304,6 @@ class ImageService(ImageServiceABC):
                 self._services.logger.warn(f"Failed to parse session graph: {e}")
                 graph = None

-            metadata = self._services.image_records.get_metadata(image_name)
             return ImageMetadata(graph=graph, metadata=metadata)
         except ImageRecordNotFoundException:
             self._services.logger.error("Image record not found")
@@ -32,6 +32,7 @@ class InvocationServices:
     logger: "Logger"
     model_manager: "ModelManagerServiceBase"
     processor: "InvocationProcessorABC"
+    performance_statistics: "InvocationStatsServiceBase"
     queue: "InvocationQueueABC"

     def __init__(
@@ -47,6 +48,7 @@ class InvocationServices:
         logger: "Logger",
         model_manager: "ModelManagerServiceBase",
         processor: "InvocationProcessorABC",
+        performance_statistics: "InvocationStatsServiceBase",
         queue: "InvocationQueueABC",
     ):
         self.board_images = board_images
@@ -61,4 +63,5 @@ class InvocationServices:
         self.logger = logger
         self.model_manager = model_manager
         self.processor = processor
+        self.performance_statistics = performance_statistics
         self.queue = queue
invokeai/app/services/invocation_stats.py (new file, 223 lines)
@@ -0,0 +1,223 @@
+# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
+"""Utility to collect execution time and GPU usage stats on invocations in flight"""
+
+"""
+Usage:
+
+statistics = InvocationStatsService(graph_execution_manager)
+with statistics.collect_stats(invocation, graph_execution_state.id):
+      ... execute graphs...
+statistics.log_stats()
+
+Typical output:
+[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
+[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node                 Calls  Seconds  VRAM Used
+[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader        1   0.005s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip                1   0.004s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel                   2   0.512s      0.26G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int                 1   0.001s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size            1   0.001s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate                  1   0.001s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator     1   0.002s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise                    1   0.002s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l                      1   3.541s      1.93G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i                      1   0.679s      0.58G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME:  4.749s
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
+
+The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
+writes to the system log is stored in InvocationServices.performance_statistics.
+"""
+
+import time
+from abc import ABC, abstractmethod
+from contextlib import AbstractContextManager
+from dataclasses import dataclass, field
+from typing import Dict
+
+import torch
+
+import invokeai.backend.util.logging as logger
+
+from ..invocations.baseinvocation import BaseInvocation
+from .graph import GraphExecutionState
+from .item_storage import ItemStorageABC
+
+
+class InvocationStatsServiceBase(ABC):
+    "Abstract base class for recording node memory/time performance statistics"
+
+    @abstractmethod
+    def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
+        """
+        Initialize the InvocationStatsService and reset counters to zero
+        :param graph_execution_manager: Graph execution manager for this session
+        """
+        pass
+
+    @abstractmethod
+    def collect_stats(
+        self,
+        invocation: BaseInvocation,
+        graph_execution_state_id: str,
+    ) -> AbstractContextManager:
+        """
+        Return a context object that will capture the statistics on the execution
+        of the invocation. Use with: to place around the part of the code that executes the invocation.
+        :param invocation: BaseInvocation object from the current graph.
+        :param graph_execution_state: GraphExecutionState object from the current session.
+        """
+        pass
+
+    @abstractmethod
+    def reset_stats(self, graph_execution_state_id: str):
+        """
+        Reset all statistics for the indicated graph
+        :param graph_execution_state_id
+        """
+        pass
+
+    @abstractmethod
+    def reset_all_stats(self):
+        """Zero all statistics"""
+        pass
+
+    @abstractmethod
+    def update_invocation_stats(
+        self,
+        graph_id: str,
+        invocation_type: str,
+        time_used: float,
+        vram_used: float,
+    ):
+        """
+        Add timing information on execution of a node. Usually used internally.
+        :param graph_id: ID of the graph that is currently executing
+        :param invocation_type: String literal type of the node
+        :param time_used: Time used by node's execution (sec)
+        :param vram_used: Maximum VRAM used during execution (GB)
+        """
+        pass
+
+    @abstractmethod
+    def log_stats(self):
+        """
+        Write out the accumulated statistics to the log or somewhere else.
+        """
+        pass
+
+
+@dataclass
+class NodeStats:
+    """Class for tracking execution stats of an invocation node"""
+
+    calls: int = 0
+    time_used: float = 0.0  # seconds
+    max_vram: float = 0.0  # GB
+
+
+@dataclass
+class NodeLog:
+    """Class for tracking node usage"""
+
+    # {node_type => NodeStats}
+    nodes: Dict[str, NodeStats] = field(default_factory=dict)
+
+
+class InvocationStatsService(InvocationStatsServiceBase):
+    """Accumulate performance information about a running graph. Collects time spent in each node,
+    as well as the maximum and current VRAM utilisation for CUDA systems"""
+
+    def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
+        self.graph_execution_manager = graph_execution_manager
+        # {graph_id => NodeLog}
+        self._stats: Dict[str, NodeLog] = {}
+
+    class StatsContext:
+        def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
+            self.invocation = invocation
+            self.collector = collector
+            self.graph_id = graph_id
+            self.start_time = 0
+
+        def __enter__(self):
+            self.start_time = time.time()
+            if torch.cuda.is_available():
+                torch.cuda.reset_peak_memory_stats()
+
+        def __exit__(self, *args):
+            self.collector.update_invocation_stats(
+                self.graph_id,
+                self.invocation.type,
+                time.time() - self.start_time,
+                torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
+            )
+
+    def collect_stats(
+        self,
+        invocation: BaseInvocation,
+        graph_execution_state_id: str,
+    ) -> StatsContext:
+        """
+        Return a context object that will capture the statistics.
+        :param invocation: BaseInvocation object from the current graph.
+        :param graph_execution_state: GraphExecutionState object from the current session.
+        """
+        if not self._stats.get(graph_execution_state_id):  # first time we're seeing this
+            self._stats[graph_execution_state_id] = NodeLog()
+        return self.StatsContext(invocation, graph_execution_state_id, self)
+
+    def reset_all_stats(self):
+        """Zero all statistics"""
+        self._stats = {}
+
+    def reset_stats(self, graph_execution_id: str):
+        """Zero the statistics for the indicated graph."""
+        try:
+            self._stats.pop(graph_execution_id)
+        except KeyError:
+            logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
+
+    def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
+        """
+        Add timing information on execution of a node. Usually used internally.
+        :param graph_id: ID of the graph that is currently executing
+        :param invocation_type: String literal type of the node
+        :param time_used: Floating point seconds used by node's execution
+        """
+        if not self._stats[graph_id].nodes.get(invocation_type):
+            self._stats[graph_id].nodes[invocation_type] = NodeStats()
+        stats = self._stats[graph_id].nodes[invocation_type]
+        stats.calls += 1
+        stats.time_used += time_used
+        stats.max_vram = max(stats.max_vram, vram_used)
+
+    def log_stats(self):
+        """
+        Send the statistics to the system logger at the info level.
+        Stats will only be printed when the execution of the graph is complete.
+        """
+        completed = set()
+        for graph_id, node_log in self._stats.items():
+            current_graph_state = self.graph_execution_manager.get(graph_id)
+            if not current_graph_state.is_complete():
+                continue
+
+            total_time = 0
+            logger.info(f"Graph stats: {graph_id}")
+            logger.info("Node                 Calls  Seconds  VRAM Used")
+            for node_type, stats in self._stats[graph_id].nodes.items():
+                logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
+                total_time += stats.time_used
+
+            logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
+            if torch.cuda.is_available():
+                logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
+
+            completed.add(graph_id)
+
+        for graph_id in completed:
+            del self._stats[graph_id]
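Note: StatsContext above is an ordinary Python context manager: __enter__ records the start time and resets the CUDA peak-memory counter, __exit__ reports elapsed seconds and peak VRAM. The same pattern in isolation, using only standard time and torch calls:

    import time
    import torch

    class TimedBlock:
        def __enter__(self):
            self.start = time.time()
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
            return self

        def __exit__(self, *args):
            self.seconds = time.time() - self.start
            # max_memory_allocated() reports peak bytes since the last reset
            self.vram_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0

    with TimedBlock() as t:
        sum(i * i for i in range(100_000))  # stand-in for executing a node
    print(f"{t.seconds:.3f}s, peak VRAM {t.vram_gb:.2f}G")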
@@ -3,9 +3,10 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+from logging import Logger
 from pathlib import Path
 from pydantic import Field
-from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
+from typing import Literal, Optional, Union, Callable, List, Tuple, TYPE_CHECKING
 from types import ModuleType

 from invokeai.backend.model_management import (
@@ -193,7 +194,7 @@ class ModelManagerServiceBase(ABC):
         self,
         model_name: str,
         base_model: BaseModelType,
-        model_type: Union[ModelType.Main, ModelType.Vae],
+        model_type: Literal[ModelType.Main, ModelType.Vae],
     ) -> AddModelResult:
         """
         Convert a checkpoint file into a diffusers folder, deleting the cached
@@ -292,7 +293,7 @@ class ModelManagerService(ModelManagerServiceBase):
     def __init__(
         self,
         config: InvokeAIAppConfig,
-        logger: ModuleType,
+        logger: Logger,
     ):
         """
         Initialize with the path to the models.yaml config file.
@@ -396,7 +397,7 @@ class ModelManagerService(ModelManagerServiceBase):
             model_type,
         )

-    def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
+    def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
         """
         Given a model name returns a dict-like (OmegaConf) object describing it.
         """
@@ -416,7 +417,7 @@ class ModelManagerService(ModelManagerServiceBase):
         """
         return self.mgr.list_models(base_model, model_type)

-    def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
+    def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
         """
         Return information about the model using the same format as list_models()
         """
@@ -429,7 +430,7 @@ class ModelManagerService(ModelManagerServiceBase):
         model_type: ModelType,
         model_attributes: dict,
         clobber: bool = False,
-    ) -> None:
+    ) -> AddModelResult:
         """
         Update the named model with a dictionary of attributes. Will fail with an
         assertion error if the name already exists. Pass clobber=True to overwrite.
@@ -478,7 +479,7 @@ class ModelManagerService(ModelManagerServiceBase):
         self,
         model_name: str,
         base_model: BaseModelType,
-        model_type: Union[ModelType.Main, ModelType.Vae],
+        model_type: Literal[ModelType.Main, ModelType.Vae],
         convert_dest_directory: Optional[Path] = Field(
             default=None, description="Optional directory location for merged model"
         ),
@@ -573,9 +574,9 @@ class ModelManagerService(ModelManagerServiceBase):
             default=None, description="Base model shared by all models to be merged"
         ),
         merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
-        alpha: Optional[float] = 0.5,
+        alpha: float = 0.5,
         interp: Optional[MergeInterpolationMethod] = None,
-        force: Optional[bool] = False,
+        force: bool = False,
         merge_dest_directory: Optional[Path] = Field(
             default=None, description="Optional directory location for merged model"
         ),
@@ -633,8 +634,8 @@ class ModelManagerService(ModelManagerServiceBase):
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
-        new_name: str = None,
-        new_base: BaseModelType = None,
+        new_name: Optional[str] = None,
+        new_base: Optional[BaseModelType] = None,
     ):
         """
         Rename the indicated model. Can provide a new name and/or a new base.
invokeai/app/services/models/board_image.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+from pydantic import Field
+
+from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
+
+
+class BoardImage(BaseModelExcludeNull):
+    board_id: str = Field(description="The id of the board")
+    image_name: str = Field(description="The name of the image")
@@ -1,10 +1,11 @@
 from typing import Optional, Union
 from datetime import datetime
-from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
+from pydantic import Field
 from invokeai.app.util.misc import get_iso_timestamp
+from invokeai.app.util.model_exclude_null import BaseModelExcludeNull


-class BoardRecord(BaseModel):
+class BoardRecord(BaseModelExcludeNull):
     """Deserialized board record."""

     board_id: str = Field(description="The unique ID of the board.")
@@ -1,13 +1,14 @@
 import datetime
 from typing import Optional, Union

-from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
+from pydantic import Extra, Field, StrictBool, StrictStr

 from invokeai.app.models.image import ImageCategory, ResourceOrigin
 from invokeai.app.util.misc import get_iso_timestamp
+from invokeai.app.util.model_exclude_null import BaseModelExcludeNull


-class ImageRecord(BaseModel):
+class ImageRecord(BaseModelExcludeNull):
     """Deserialized image record without metadata."""

     image_name: str = Field(description="The unique name of the image.")
@@ -40,7 +41,7 @@ class ImageRecord(BaseModel):
     """The node ID that generated this image, if it is a generated image."""


-class ImageRecordChanges(BaseModel, extra=Extra.forbid):
+class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
     """A set of changes to apply to an image record.

     Only limited changes are valid:
@@ -60,7 +61,7 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
     """The image's new `is_intermediate` flag."""


-class ImageUrlsDTO(BaseModel):
+class ImageUrlsDTO(BaseModelExcludeNull):
     """The URLs for an image and its thumbnail."""

     image_name: str = Field(description="The unique name of the image.")
@@ -76,11 +77,15 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):

     board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
     """The id of the board the image belongs to, if one exists."""

     pass


 def image_record_to_dto(
-    image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str]
+    image_record: ImageRecord,
+    image_url: str,
+    thumbnail_url: str,
+    board_id: Optional[str],
 ) -> ImageDTO:
     """Converts an image record to an image DTO."""
     return ImageDTO(
@@ -1,14 +1,15 @@
 import time
 import traceback
-from threading import Event, Thread, BoundedSemaphore
-
-from ..invocations.baseinvocation import InvocationContext
-from .invocation_queue import InvocationQueueItem
-from .invoker import InvocationProcessorABC, Invoker
-from ..models.exceptions import CanceledException
+from threading import BoundedSemaphore, Event, Thread

 import invokeai.backend.util.logging as logger

+from ..invocations.baseinvocation import InvocationContext
+from ..models.exceptions import CanceledException
+from .invocation_queue import InvocationQueueItem
+from .invocation_stats import InvocationStatsServiceBase
+from .invoker import InvocationProcessorABC, Invoker


 class DefaultInvocationProcessor(InvocationProcessorABC):
     __invoker_thread: Thread
@@ -35,6 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
     def __process(self, stop_event: Event):
         try:
             self.__threadLimit.acquire()
+            statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
+
             while not stop_event.is_set():
                 try:
                     queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
@@ -83,35 +86,38 @@ class DefaultInvocationProcessor(InvocationProcessorABC):

                     # Invoke
                    try:
-                        outputs = invocation.invoke(
-                            InvocationContext(
-                                services=self.__invoker.services,
-                                graph_execution_state_id=graph_execution_state.id,
+                        with statistics.collect_stats(invocation, graph_execution_state.id):
+                            outputs = invocation.invoke(
+                                InvocationContext(
+                                    services=self.__invoker.services,
+                                    graph_execution_state_id=graph_execution_state.id,
+                                )
                             )
-                        )

-                        # Check queue to see if this is canceled, and skip if so
-                        if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
-                            continue
+                            # Check queue to see if this is canceled, and skip if so
+                            if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
+                                continue

-                        # Save outputs and history
-                        graph_execution_state.complete(invocation.id, outputs)
+                            # Save outputs and history
+                            graph_execution_state.complete(invocation.id, outputs)

-                        # Save the state changes
-                        self.__invoker.services.graph_execution_manager.set(graph_execution_state)
+                            # Save the state changes
+                            self.__invoker.services.graph_execution_manager.set(graph_execution_state)

-                        # Send complete event
-                        self.__invoker.services.events.emit_invocation_complete(
-                            graph_execution_state_id=graph_execution_state.id,
-                            node=invocation.dict(),
-                            source_node_id=source_node_id,
-                            result=outputs.dict(),
-                        )
+                            # Send complete event
+                            self.__invoker.services.events.emit_invocation_complete(
+                                graph_execution_state_id=graph_execution_state.id,
+                                node=invocation.dict(),
+                                source_node_id=source_node_id,
+                                result=outputs.dict(),
+                            )
+                            statistics.log_stats()

                     except KeyboardInterrupt:
                         pass

                     except CanceledException:
+                        statistics.reset_stats(graph_execution_state.id)
                         pass

                     except Exception as e:
@@ -133,7 +139,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                             error_type=e.__class__.__name__,
                             error=error,
                         )
-
+                        statistics.reset_stats(graph_execution_state.id)
                         pass

                 # Check queue to see if this is canceled, and skip if so
@@ -20,6 +20,6 @@ class LocalUrlService(UrlServiceBase):

         # These paths are determined by the routes in invokeai/app/api/routers/images.py
         if thumbnail:
-            return f"{self._base_url}/images/{image_basename}/thumbnail"
+            return f"{self._base_url}/images/i/{image_basename}/thumbnail"

-        return f"{self._base_url}/images/{image_basename}/full"
+        return f"{self._base_url}/images/i/{image_basename}/full"
@@ -18,5 +18,5 @@ SEED_MAX = np.iinfo(np.uint32).max


 def get_random_seed():
-    rng = np.random.default_rng(seed=0)
+    rng = np.random.default_rng(seed=None)
     return int(rng.integers(0, SEED_MAX))
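Note: this one-character-looking change is a substantive bugfix: np.random.default_rng(seed=0) constructs an identically-seeded generator on every call, so every "random" seed was the same value, while seed=None seeds from OS entropy. Illustration:

    import numpy as np

    SEED_MAX = np.iinfo(np.uint32).max

    # old behavior: a fresh seed=0 generator always yields the same first draw
    a = int(np.random.default_rng(seed=0).integers(0, SEED_MAX))
    b = int(np.random.default_rng(seed=0).integers(0, SEED_MAX))
    assert a == b

    # new behavior: entropy-seeded generators differ (with overwhelming probability)
    c = int(np.random.default_rng(seed=None).integers(0, SEED_MAX))
    d = int(np.random.default_rng(seed=None).integers(0, SEED_MAX))
    print(c, d)  # almost surely different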
invokeai/app/util/model_exclude_null.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from typing import Any
+from pydantic import BaseModel
+
+
+"""
+We want to exclude null values from objects that make their way to the client.
+
+Unfortunately there is no built-in way to do this in pydantic, so we need to override the default
+dict method to do this.
+
+From https://github.com/tiangolo/fastapi/discussions/8882#discussioncomment-5154541
+"""
+
+
+class BaseModelExcludeNull(BaseModel):
+    def dict(self, *args, **kwargs) -> dict[str, Any]:
+        """
+        Override the default dict method to exclude None values in the response
+        """
+        kwargs.pop("exclude_none", None)
+        return super().dict(*args, exclude_none=True, **kwargs)
+
+    pass
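Note: a quick illustration of what the override buys, assuming pydantic v1 dict() semantics (the BaseModel API this code overrides):

    from typing import Optional
    from pydantic import BaseModel

    class Plain(BaseModel):
        name: str
        board_id: Optional[str] = None

    class NoNulls(Plain):
        def dict(self, *args, **kwargs):
            kwargs.pop("exclude_none", None)
            return super().dict(*args, exclude_none=True, **kwargs)

    assert Plain(name="x").dict() == {"name": "x", "board_id": None}
    assert NoNulls(name="x").dict() == {"name": "x"}  # nulls never reach the client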
@@ -1,25 +1,11 @@
 """
 invokeai.backend.generator.img2img descends from .generator
 """
-from typing import Optional
-
 import torch
-from accelerate.utils import set_seed
-from diffusers import logging
-
-from ..stable_diffusion import (
-    ConditioningData,
-    PostprocessingSettings,
-    StableDiffusionGeneratorPipeline,
-)
+
 from .base import Generator


 class Img2Img(Generator):
-    def __init__(self, model, precision):
-        super().__init__(model, precision)
-        self.init_latent = None  # by get_noise()
-
     def get_make_image(
         self,
         sampler,
@@ -42,51 +28,4 @@ class Img2Img(Generator):
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it.
         """
-        self.perlin = perlin
-
-        # noinspection PyTypeChecker
-        pipeline: StableDiffusionGeneratorPipeline = self.model
-        pipeline.scheduler = sampler
-
-        uc, c, extra_conditioning_info = conditioning
-        conditioning_data = ConditioningData(
-            uc,
-            c,
-            cfg_scale,
-            extra_conditioning_info,
-            postprocessing_settings=PostprocessingSettings(
-                threshold=threshold,
-                warmup=warmup,
-                h_symmetry_time_pct=h_symmetry_time_pct,
-                v_symmetry_time_pct=v_symmetry_time_pct,
-            ),
-        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
-
-        def make_image(x_T: torch.Tensor, seed: int):
-            # FIXME: use x_T for initial seeded noise
-            # We're not at the moment because the pipeline automatically resizes init_image if
-            # necessary, which the x_T input might not match.
-            # In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
-            logging.set_verbosity_error()  # quench safety check warnings
-            pipeline_output = pipeline.img2img_from_embeddings(
-                init_image,
-                strength,
-                steps,
-                conditioning_data,
-                noise_func=self.get_noise_like,
-                callback=step_callback,
-                seed=seed,
-            )
-            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
-                attention_maps_callback(pipeline_output.attention_map_saver)
-            return pipeline.numpy_to_pil(pipeline_output.images)[0]
-
-        return make_image
-
-    def get_noise_like(self, like: torch.Tensor):
-        device = like.device
-        x = torch.randn_like(like, device=device)
-        if self.perlin > 0.0:
-            shape = like.shape
-            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
-        return x
+        raise NotImplementedError("replaced by invokeai.app.invocations.latent.LatentsToLatentsInvocation")
@@ -377,3 +377,11 @@ class Inpaint(Img2Img):
         )

         return corrected_result
+
+    def get_noise_like(self, like: torch.Tensor):
+        device = like.device
+        x = torch.randn_like(like, device=device)
+        if self.perlin > 0.0:
+            shape = like.shape
+            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
+        return x
@@ -12,16 +12,17 @@ def check_invokeai_root(config: InvokeAIAppConfig):
         assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
         assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
         assert config.models_path.exists(), f"{config.models_path} not found"
-        for model in [
-            "CLIP-ViT-bigG-14-laion2B-39B-b160k",
-            "bert-base-uncased",
-            "clip-vit-large-patch14",
-            "sd-vae-ft-mse",
-            "stable-diffusion-2-clip",
-            "stable-diffusion-safety-checker",
-        ]:
-            path = config.models_path / f"core/convert/{model}"
-            assert path.exists(), f"{path} is missing"
+        if not config.ignore_missing_core_models:
+            for model in [
+                "CLIP-ViT-bigG-14-laion2B-39B-b160k",
+                "bert-base-uncased",
+                "clip-vit-large-patch14",
+                "sd-vae-ft-mse",
+                "stable-diffusion-2-clip",
+                "stable-diffusion-safety-checker",
+            ]:
+                path = config.models_path / f"core/convert/{model}"
+                assert path.exists(), f"{path} is missing"
     except Exception as e:
         print()
         print(f"An exception has occurred: {str(e)}")
@@ -32,5 +33,10 @@ def check_invokeai_root(config: InvokeAIAppConfig):
         print(
             '** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
         )
+        print(
+            '** (To skip this check completely, add "--ignore_missing_core_models" to your CLI args. Not installing '
+            "these core models will prevent the loading of some or all .safetensors and .ckpt files. However, you can "
+            "always come back and install these core models in the future.)"
+        )
         input("Press any key to continue...")
         sys.exit(0)
@@ -10,15 +10,17 @@ import sys
 import argparse
 import io
 import os
+import psutil
 import shutil
 import textwrap
+import torch
 import traceback
 import yaml
 import warnings
 from argparse import Namespace
+from enum import Enum
 from pathlib import Path
 from shutil import get_terminal_size
 from typing import get_type_hints
 from urllib import request

 import npyscreen
@@ -44,6 +46,8 @@ from invokeai.app.services.config import (
 )
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
+
+# TO DO - Move all the frontend code into invokeai.frontend.install
 from invokeai.frontend.install.widgets import (
     SingleSelectColumns,
     CenteredButtonPress,
@@ -53,6 +57,7 @@ from invokeai.frontend.install.widgets import (
     CyclingForm,
     MIN_COLS,
     MIN_LINES,
+    WindowTooSmallException,
 )
 from invokeai.backend.install.legacy_arg_parsing import legacy_parser
 from invokeai.backend.install.model_install_backend import (
@@ -61,6 +66,7 @@ from invokeai.backend.install.model_install_backend import (
     ModelInstall,
 )
 from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
+from pydantic.error_wrappers import ValidationError

 warnings.filterwarnings("ignore")
 transformers.logging.set_verbosity_error()
@@ -76,6 +82,13 @@ Default_config_file = config.model_conf_path
 SD_Configs = config.legacy_conf_path

 PRECISION_CHOICES = ["auto", "float16", "float32"]
+GB = 1073741824  # GB in bytes
+HAS_CUDA = torch.cuda.is_available()
+_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
+
+
+MAX_VRAM /= GB
+MAX_RAM = psutil.virtual_memory().total / GB

 INIT_FILE_PREAMBLE = """# InvokeAI initialization file
 # This is the InvokeAI initialization file, which contains command-line default values.
@@ -86,6 +99,12 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
 logger = InvokeAILogger.getLogger()


+class DummyWidgetValue(Enum):
+    zero = 0
+    true = True
+    false = False
+
+
 # --------------------------------------------
 def postscript(errors: None):
     if not any(errors):
@@ -376,15 +395,47 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
             max_width=80,
             scroll_exit=True,
         )
-        self.max_cache_size = self.add_widget_intelligent(
-            IntTitleSlider,
-            name="Size of the RAM cache used for fast model switching (GB)",
-            value=old_opts.max_cache_size,
-            out_of=20,
-            lowest=3,
-            begin_entry_at=6,
+        self.nextrely += 1
+        self.add_widget_intelligent(
+            npyscreen.TitleFixedText,
+            name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
+            begin_entry_at=0,
+            editable=False,
+            color="CONTROL",
+            scroll_exit=True,
+        )
+        self.nextrely -= 1
+        self.max_cache_size = self.add_widget_intelligent(
+            npyscreen.Slider,
+            value=clip(old_opts.max_cache_size, range=(3.0, MAX_RAM), step=0.5),
+            out_of=round(MAX_RAM),
+            lowest=0.0,
+            step=0.5,
+            relx=8,
             scroll_exit=True,
         )
+        if HAS_CUDA:
+            self.nextrely += 1
+            self.add_widget_intelligent(
+                npyscreen.TitleFixedText,
+                name="VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
+                begin_entry_at=0,
+                editable=False,
+                color="CONTROL",
+                scroll_exit=True,
+            )
+            self.nextrely -= 1
+            self.max_vram_cache_size = self.add_widget_intelligent(
+                npyscreen.Slider,
+                value=clip(old_opts.max_vram_cache_size, range=(0, MAX_VRAM), step=0.25),
+                out_of=round(MAX_VRAM * 2) / 2,
+                lowest=0.0,
+                relx=8,
+                step=0.25,
+                scroll_exit=True,
+            )
+        else:
+            self.max_vram_cache_size = DummyWidgetValue.zero
         self.nextrely += 1
         self.outdir = self.add_widget_intelligent(
             FileBox,
@@ -401,7 +452,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
         self.autoimport_dirs = {}
         self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
             FileBox,
-            name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
+            name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
             value=str(config.root_path / config.autoimport_dir),
             select_dir=True,
             must_exist=False,
@@ -476,6 +527,7 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE
                 "outdir",
                 "free_gpu_mem",
                 "max_cache_size",
+                "max_vram_cache_size",
                 "xformers_enabled",
                 "always_use_cpu",
             ]:
@@ -553,6 +605,16 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
     )


+# -------------------------------------
+def clip(value: float, range: tuple[float, float], step: float) -> float:
+    minimum, maximum = range
+    if value < minimum:
+        value = minimum
+    if value > maximum:
+        value = maximum
+    return round(value / step) * step
+
+
 # -------------------------------------
 def initialize_rootdir(root: Path, yes_to_all: bool = False):
     logger.info("Initializing InvokeAI runtime directory")
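Note: the new clip() helper both clamps to the range and snaps to the step grid, which keeps the npyscreen sliders above on sane values. For example:

    assert clip(7.3, range=(3.0, 16.0), step=0.5) == 7.5    # snapped to the 0.5 grid
    assert clip(1.0, range=(3.0, 16.0), step=0.5) == 3.0    # clamped up to the minimum
    assert clip(99.0, range=(3.0, 16.0), step=0.25) == 16.0  # clamped down to the maximum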
@@ -592,13 +654,13 @@ def maybe_create_models_yaml(root: Path):

 # -------------------------------------
 def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
-    # parse_args() will read from init file if present
     invokeai_opts = default_startup_options(initfile)
     invokeai_opts.root = program_opts.root

-    # The third argument is needed in the Windows 11 environment to
-    # launch a console window running this program.
-    set_min_terminal_size(MIN_COLS, MIN_LINES)
+    if not set_min_terminal_size(MIN_COLS, MIN_LINES):
+        raise WindowTooSmallException(
+            "Could not increase terminal size. Try running again with a larger window or smaller font size."
+        )

     # the install-models application spawns a subprocess to install
     # models, and will crash unless this is set before running.
@@ -654,10 +716,13 @@ def migrate_init_file(legacy_format: Path):
     old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
     new = InvokeAIAppConfig.get_config()

-    fields = list(get_type_hints(InvokeAIAppConfig).keys())
+    fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
     for attr in fields:
         if hasattr(old, attr):
-            setattr(new, attr, getattr(old, attr))
+            try:
+                setattr(new, attr, getattr(old, attr))
+            except ValidationError as e:
+                print(f"* Ignoring incompatible value for field {attr}:\n  {str(e)}")

     # a few places where the field names have changed and we have to
     # manually add in the new names/values
@@ -777,6 +842,7 @@ def main():

     models_to_download = default_user_selections(opt)
     new_init_file = config.root_path / "invokeai.yaml"
+
     if opt.yes_to_all:
         write_default_options(opt, new_init_file)
         init_options = Namespace(precision="float32" if opt.full_precision else "float16")
@@ -802,6 +868,8 @@ def main():
         postscript(errors=errors)
         if not opt.yes_to_all:
             input("Press any key to continue...")
+    except WindowTooSmallException as e:
+        logger.error(str(e))
     except KeyboardInterrupt:
         print("\nGoodbye! Come back soon.")
@@ -591,7 +591,6 @@ script, which will perform a full upgrade in place.""",
     # TODO: revisit - don't rely on invokeai.yaml to exist yet!
     dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
     if not dest_is_setup:
-        import invokeai.frontend.install.invokeai_configure
         from invokeai.backend.install.invokeai_configure import initialize_rootdir

         initialize_rootdir(dest_root, True)
@@ -13,6 +13,7 @@ import requests
 from diffusers import DiffusionPipeline
 from diffusers import logging as dlogging
+import onnx
 import torch
 from huggingface_hub import hf_hub_url, HfFolder, HfApi
 from omegaconf import OmegaConf
 from tqdm import tqdm
@@ -23,6 +24,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
 from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
 from invokeai.backend.util import download_with_resume
+from invokeai.backend.util.devices import torch_dtype, choose_torch_device
 from ..util.logging import InvokeAILogger

 warnings.filterwarnings("ignore")
@@ -99,9 +101,9 @@ class ModelInstall(object):
     def __init__(
         self,
         config: InvokeAIAppConfig,
-        prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
-        model_manager: ModelManager = None,
-        access_token: str = None,
+        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
+        model_manager: Optional[ModelManager] = None,
+        access_token: Optional[str] = None,
     ):
         self.config = config
         self.mgr = model_manager or ModelManager(config.model_conf_path)
@@ -303,7 +305,7 @@ class ModelInstall(object):

         with TemporaryDirectory(dir=self.config.models_path) as staging:
             staging = Path(staging)
-            if "model_index.json" in files and "unet/model.onnx" not in files:
+            if "model_index.json" in files:
                 location = self._download_hf_pipeline(repo_id, staging)  # pipeline
             elif "unet/model.onnx" in files:
                 location = self._download_hf_model(repo_id, files, staging)
@@ -416,15 +418,25 @@ class ModelInstall(object):
         does a save_pretrained() to the indicated staging area.
         """
         _, name = repo_id.split("/")
-        revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"]
+        precision = torch_dtype(choose_torch_device())
+        variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]

         model = None
-        for revision in revisions:
+        for variant in variants:
             try:
-                model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
-            except:  # most errors are due to fp16 not being present. Fix this to catch other errors
-                pass
+                model = DiffusionPipeline.from_pretrained(
+                    repo_id,
+                    variant=variant,
+                    torch_dtype=precision,
+                    safety_checker=None,
+                )
+            except Exception as e:  # most errors are due to fp16 not being present. Fix this to catch other errors
+                if "fp16" not in str(e):
+                    print(e)
+
             if model:
                 break

         if not model:
             logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
             return None
@@ -13,3 +13,4 @@ from .models import (
     DuplicateModelException,
 )
 from .model_merge import ModelMerger, MergeInterpolationMethod
+from .lora import ModelPatcher
@@ -20,424 +20,6 @@ from diffusers.models import UNet2DConditionModel
 from safetensors.torch import load_file
 from transformers import CLIPTextModel, CLIPTokenizer

 # TODO: rename and split this file
-
-
-class LoRALayerBase:
-    # rank: Optional[int]
-    # alpha: Optional[float]
-    # bias: Optional[torch.Tensor]
-    # layer_key: str
-
-    # @property
-    # def scale(self):
-    #     return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
-
-    def __init__(
-        self,
-        layer_key: str,
-        values: dict,
-    ):
-        if "alpha" in values:
-            self.alpha = values["alpha"].item()
-        else:
-            self.alpha = None
-
-        if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
-            self.bias = torch.sparse_coo_tensor(
-                values["bias_indices"],
-                values["bias_values"],
-                tuple(values["bias_size"]),
-            )
-
-        else:
-            self.bias = None
-
-        self.rank = None  # set in layer implementation
-        self.layer_key = layer_key
-
-    def forward(
-        self,
-        module: torch.nn.Module,
-        input_h: Any,  # for real looks like Tuple[torch.nn.Tensor] but not sure
-        multiplier: float,
-    ):
-        if type(module) == torch.nn.Conv2d:
-            op = torch.nn.functional.conv2d
-            extra_args = dict(
-                stride=module.stride,
-                padding=module.padding,
-                dilation=module.dilation,
-                groups=module.groups,
-            )
-
-        else:
-            op = torch.nn.functional.linear
-            extra_args = {}
-
-        weight = self.get_weight()
-
-        bias = self.bias if self.bias is not None else 0
-        scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
-        return (
-            op(
-                *input_h,
-                (weight + bias).view(module.weight.shape),
-                None,
-                **extra_args,
-            )
-            * multiplier
-            * scale
-        )
-
-    def get_weight(self):
-        raise NotImplementedError()
-
-    def calc_size(self) -> int:
-        model_size = 0
-        for val in [self.bias]:
-            if val is not None:
-                model_size += val.nelement() * val.element_size()
-        return model_size
-
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype)
-
-
-# TODO: find and debug lora/locon with bias
-class LoRALayer(LoRALayerBase):
-    # up: torch.Tensor
-    # mid: Optional[torch.Tensor]
-    # down: torch.Tensor
-
-    def __init__(
-        self,
-        layer_key: str,
-        values: dict,
-    ):
-        super().__init__(layer_key, values)
-
-        self.up = values["lora_up.weight"]
-        self.down = values["lora_down.weight"]
-        if "lora_mid.weight" in values:
-            self.mid = values["lora_mid.weight"]
-        else:
-            self.mid = None
-
-        self.rank = self.down.shape[0]
-
-    def get_weight(self):
-        if self.mid is not None:
-            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
-            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
-            weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
-        else:
-            weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
-
-        return weight
-
-    def calc_size(self) -> int:
-        model_size = super().calc_size()
-        for val in [self.up, self.mid, self.down]:
-            if val is not None:
-                model_size += val.nelement() * val.element_size()
-        return model_size
-
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        super().to(device=device, dtype=dtype)
-
-        self.up = self.up.to(device=device, dtype=dtype)
-        self.down = self.down.to(device=device, dtype=dtype)
-
-        if self.mid is not None:
-            self.mid = self.mid.to(device=device, dtype=dtype)
-
-
-class LoHALayer(LoRALayerBase):
-    # w1_a: torch.Tensor
-    # w1_b: torch.Tensor
-    # w2_a: torch.Tensor
-    # w2_b: torch.Tensor
-    # t1: Optional[torch.Tensor] = None
-    # t2: Optional[torch.Tensor] = None
-
-    def __init__(
-        self,
-        layer_key: str,
-        values: dict,
-    ):
-        super().__init__(layer_key, values)
-
-        self.w1_a = values["hada_w1_a"]
-        self.w1_b = values["hada_w1_b"]
-        self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
|
||||
if "hada_t1" in values:
|
||||
self.t1 = values["hada_t1"]
|
||||
else:
|
||||
self.t1 = None
|
||||
|
||||
if "hada_t2" in values:
|
||||
self.t2 = values["hada_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
|
||||
def get_weight(self):
|
||||
if self.t1 is None:
|
||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
# w1_b: Optional[torch.Tensor] = None
|
||||
# w2: Optional[torch.Tensor] = None
|
||||
# w2_a: Optional[torch.Tensor] = None
|
||||
# w2_b: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
if "lokr_w1" in values:
|
||||
self.w1 = values["lokr_w1"]
|
||||
self.w1_a = None
|
||||
self.w1_b = None
|
||||
else:
|
||||
self.w1 = None
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
|
||||
if "lokr_w2" in values:
|
||||
self.w2 = values["lokr_w2"]
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
else:
|
||||
self.w2 = None
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
|
||||
if "lokr_t2" in values:
|
||||
self.t2 = values["lokr_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
if "lokr_w1_b" in values:
|
||||
self.rank = values["lokr_w1_b"].shape[0]
|
||||
elif "lokr_w2_b" in values:
|
||||
self.rank = values["lokr_w2_b"].shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self):
|
||||
w1 = self.w1
|
||||
if w1 is None:
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoRAModel: # (torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, LoRALayer]
|
||||
_device: torch.device
|
||||
_dtype: torch.dtype
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, LoRALayer],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self._name = name
|
||||
self._device = device or torch.cpu
|
||||
self._dtype = dtype or torch.float32
|
||||
self.layers = layers
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> LoRAModel:
|
||||
# TODO: try revert if exception?
|
||||
for key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
self._device = device
|
||||
self._dtype = dtype
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
|
||||
model = cls(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
name=file_path.stem, # TODO:
|
||||
layers=dict(),
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
|
||||
state_dict = cls._group_state(state_dict)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
# lora and locon
|
||||
if "lora_down.weight" in values:
|
||||
layer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_b" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
else:
|
||||
# TODO: diff/ia3/... format
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
|
||||
return
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: dict):
|
||||
state_dict_groupped = dict()
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = dict()
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
|
||||
|
||||
"""
|
||||
loras = [
|
||||
(lora_model1, 0.7),
|
@@ -516,6 +98,26 @@ class ModelPatcher:
         with cls.apply_lora(text_encoder, loras, "lora_te_"):
             yield

+    @classmethod
+    @contextmanager
+    def apply_sdxl_lora_text_encoder(
+        cls,
+        text_encoder: CLIPTextModel,
+        loras: List[Tuple[LoRAModel, float]],
+    ):
+        with cls.apply_lora(text_encoder, loras, "lora_te1_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_sdxl_lora_text_encoder2(
+        cls,
+        text_encoder: CLIPTextModel,
+        loras: List[Tuple[LoRAModel, float]],
+    ):
+        with cls.apply_lora(text_encoder, loras, "lora_te2_"):
+            yield
+
     @classmethod
     @contextmanager
     def apply_lora(
@@ -541,7 +143,7 @@ class ModelPatcher:
                         # with torch.autocast(device_type="cpu"):
                         layer.to(dtype=torch.float32)
                         layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-                        layer_weight = layer.get_weight() * lora_weight * layer_scale
+                        layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale

                         if module.weight.shape != layer_weight.shape:
                             # TODO: debug on lycoris
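For reference, the update this hunk computes is the standard LoRA weight merge, W' = W + multiplier * scale * ΔW, where ΔW is the layer's low-rank product. A standalone sketch (names and shapes are illustrative):

```python
import torch


def lora_patched_weight(
    orig: torch.Tensor,  # the module's original weight
    up: torch.Tensor,    # (out_features, rank)
    down: torch.Tensor,  # (rank, in_features)
    alpha: float,
    multiplier: float,
) -> torch.Tensor:
    rank = down.shape[0]
    scale = alpha / rank if (alpha and rank) else 1.0
    delta = (up @ down).reshape(orig.shape)  # low-rank update, reshaped to match
    return orig + delta * multiplier * scale


w = torch.zeros(16, 16)
patched = lora_patched_weight(w, torch.randn(16, 4), torch.randn(4, 16), alpha=4.0, multiplier=0.7)
print(patched.shape)  # torch.Size([16, 16])
```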
@@ -562,7 +164,7 @@ class ModelPatcher:
         cls,
         tokenizer: CLIPTokenizer,
         text_encoder: CLIPTextModel,
-        ti_list: List[Any],
+        ti_list: List[Tuple[str, Any]],
     ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
         init_tokens_count = None
         new_tokens_added = None

@@ -572,27 +174,27 @@ class ModelPatcher:
             ti_manager = TextualInversionManager(ti_tokenizer)
             init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings

-            def _get_trigger(ti, index):
-                trigger = ti.name
+            def _get_trigger(ti_name, index):
+                trigger = ti_name
                 if index > 0:
                     trigger += f"-!pad-{i}"
                 return f"<{trigger}>"

             # modify tokenizer
             new_tokens_added = 0
-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 for i in range(ti.embedding.shape[0]):
-                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

             # modify text_encoder
             text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
             model_embeddings = text_encoder.get_input_embeddings()

-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 ti_tokens = []
                 for i in range(ti.embedding.shape[0]):
                     embedding = ti.embedding[i]
-                    trigger = _get_trigger(ti, i)
+                    trigger = _get_trigger(ti_name, i)

                     token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                     if token_id == ti_tokenizer.unk_token_id:
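The `_get_trigger` rewrite above switches textual inversions from embedding objects to `(name, embedding)` tuples. The naming scheme it implements, shown in isolation (note that the upstream helper reads the enclosing loop variable `i` rather than its `index` argument; this sketch follows the evident intent):

```python
def trigger_names(ti_name: str, num_vectors: int) -> list:
    # Row 0 keeps the plain trigger; extra embedding rows get "-!pad-{i}" suffixes
    # so each row of a multi-vector embedding maps to exactly one tokenizer token.
    return [f"<{ti_name}>" if i == 0 else f"<{ti_name}-!pad-{i}>" for i in range(num_vectors)]


print(trigger_names("my-style", 3))  # ['<my-style>', '<my-style-!pad-1>', '<my-style-!pad-2>']
```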
@@ -637,7 +239,6 @@ class ModelPatcher:


 class TextualInversionModel:
-    name: str
     embedding: torch.Tensor  # [n, 768]|[n, 1280]

     @classmethod

@@ -651,7 +252,6 @@ class TextualInversionModel:
         file_path = Path(file_path)

         result = cls()  # TODO:
-        result.name = file_path.stem  # TODO:

         if file_path.suffix == ".safetensors":
             state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
@@ -761,7 +361,8 @@ class ONNXModelPatcher:

                 layer.to(dtype=torch.float32)
                 layer_key = layer_key.replace(prefix, "")
-                layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
+                # TODO: rewrite to pass original tensor weight(required by ia3)
+                layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
                 if layer_key in blended_loras:
                     blended_loras[layer_key] += layer_weight
                 else:
@@ -828,7 +429,7 @@ class ONNXModelPatcher:
         cls,
         tokenizer: CLIPTokenizer,
         text_encoder: IAIOnnxRuntimeModel,
-        ti_list: List[Any],
+        ti_list: List[Tuple[str, Any]],
     ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
         from .models.base import IAIOnnxRuntimeModel

@@ -841,17 +442,17 @@ class ONNXModelPatcher:
         ti_tokenizer = copy.deepcopy(tokenizer)
         ti_manager = TextualInversionManager(ti_tokenizer)

-        def _get_trigger(ti, index):
-            trigger = ti.name
+        def _get_trigger(ti_name, index):
+            trigger = ti_name
             if index > 0:
                 trigger += f"-!pad-{i}"
             return f"<{trigger}>"

         # modify tokenizer
         new_tokens_added = 0
-        for ti in ti_list:
+        for ti_name, ti in ti_list:
             for i in range(ti.embedding.shape[0]):
-                new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+                new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

         # modify text_encoder
         orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]

@@ -861,11 +462,11 @@ class ONNXModelPatcher:
             axis=0,
         )

-        for ti in ti_list:
+        for ti_name, ti in ti_list:
             ti_tokens = []
             for i in range(ti.embedding.shape[0]):
                 embedding = ti.embedding[i].detach().numpy()
-                trigger = _get_trigger(ti, i)
+                trigger = _get_trigger(ti_name, i)

                 token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                 if token_id == ti_tokenizer.unk_token_id:
@@ -28,8 +28,6 @@ import torch

 import logging
 import invokeai.backend.util.logging as logger
-from invokeai.app.services.config import get_invokeai_config
-from .lora import LoRAModel, TextualInversionModel
 from .models import BaseModelType, ModelType, SubModelType, ModelBase

 # Maximum size of the cache, in gigs
@@ -188,7 +186,7 @@ class ModelCache(object):
         cache_entry = self._cached_models.get(key, None)
         if cache_entry is None:
             self.logger.info(
-                f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}"
+                f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
             )

             # this will remove older cached models until
@@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR.
 """
 from __future__ import annotations

-import os
 import hashlib
+import os
 import textwrap
-import yaml
 import types
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
 from shutil import rmtree, move
+from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable

 import torch
+import yaml
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig

 from pydantic import BaseModel, Field

 import invokeai.backend.util.logging as logger
@@ -259,6 +259,7 @@ from .models import (
     ModelNotFoundException,
     InvalidModelException,
     DuplicateModelException,
+    ModelBase,
 )

 # We are only starting to number the config file with release 3.

@@ -361,7 +362,7 @@ class ModelManager(object):
             if model_key.startswith("_"):
                 continue
             model_name, base_model, model_type = self.parse_key(model_key)
-            model_class = MODEL_CLASSES[base_model][model_type]
+            model_class = self._get_implementation(base_model, model_type)
             # alias for config file
             model_config["model_format"] = model_config.pop("format")
             self.models[model_key] = model_class.create_config(**model_config)
@@ -381,18 +382,24 @@ class ModelManager(object):
         # causing otherwise unreferenced models to be removed from memory
         self._read_models()

-    def model_exists(
-        self,
-        model_name: str,
-        base_model: BaseModelType,
-        model_type: ModelType,
-    ) -> bool:
+    def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
         """
-        Given a model name, returns True if it is a valid
-        identifier.
+        Given a model name, returns True if it is a valid identifier.
+
         :param model_name: symbolic name of the model in models.yaml
         :param model_type: ModelType enum indicating the type of model to return
         :param base_model: BaseModelType enum indicating the base model used by this model
+        :param rescan: if True, scan_models_directory
         """
         model_key = self.create_key(model_name, base_model, model_type)
-        return model_key in self.models
+        exists = model_key in self.models
+
+        # if model not found try to find it (maybe file just pasted)
+        if rescan and not exists:
+            self.scan_models_directory(base_model=base_model, model_type=model_type)
+            exists = self.model_exists(model_name, base_model, model_type, rescan=False)
+
+        return exists

     @classmethod
     def create_key(
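The new keyword-only `rescan` flag retries the lookup exactly once after a directory scan. The pattern in isolation (a minimal sketch; the names below are illustrative, not the ModelManager API):

```python
from typing import Callable, Set


def exists_with_rescan(key: str, registry: Set[str], rescan_fn: Callable[[], None], rescan: bool = False) -> bool:
    found = key in registry
    if rescan and not found:
        rescan_fn()               # pick up model files pasted into the models dir after startup
        found = key in registry   # retry once; the recursive call above passes rescan=False
    return found
```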
@@ -443,39 +450,32 @@ class ModelManager(object):
         :param model_name: symbolic name of the model in models.yaml
         :param model_type: ModelType enum indicating the type of model to return
         :param base_model: BaseModelType enum indicating the base model used by this model
-        :param submode_typel: an ModelType enum indicating the portion of
+        :param submodel_type: an ModelType enum indicating the portion of
               the model to retrieve (e.g. ModelType.Vae)
         """
-        model_class = MODEL_CLASSES[base_model][model_type]
         model_key = self.create_key(model_name, base_model, model_type)

-        # if model not found try to find it (maybe file just pasted)
-        if model_key not in self.models:
-            self.scan_models_directory(base_model=base_model, model_type=model_type)
-            if model_key not in self.models:
-                raise ModelNotFoundException(f"Model not found - {model_key}")
+        if not self.model_exists(model_name, base_model, model_type, rescan=True):
+            raise ModelNotFoundException(f"Model not found - {model_key}")

-        model_config = self.models[model_key]
-        model_path = self.resolve_model_path(model_config.path)
+        model_config = self._get_model_config(base_model, model_name, model_type)
+
+        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
+
+        if is_submodel_override:
+            model_type = submodel_type
+            submodel_type = None
+
+        model_class = self._get_implementation(base_model, model_type)

         if not model_path.exists():
             if model_class.save_to_config:
                 self.models[model_key].error = ModelError.NotFound
-                raise Exception(f'Files for model "{model_key}" not found')
+                raise Exception(f'Files for model "{model_key}" not found at {model_path}')

             else:
                 self.models.pop(model_key, None)
-                raise ModelNotFoundException(f"Model not found - {model_key}")
-
-        # vae/movq override
-        # TODO:
-        if submodel_type is not None and hasattr(model_config, submodel_type):
-            override_path = getattr(model_config, submodel_type)
-            if override_path:
-                model_path = self.app_config.root_path / override_path
-                model_type = submodel_type
-                submodel_type = None
-                model_class = MODEL_CLASSES[base_model][model_type]
+                raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')

         # TODO: path
         # TODO: is it accurate to use path as id
@@ -513,12 +513,61 @@ class ModelManager(object):
             _cache=self.cache,
         )

+    def _get_model_path(
+        self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
+    ) -> (Path, bool):
+        """Extract a model's filesystem path from its config.
+
+        :return: The fully qualified Path of the module (or submodule).
+        """
+        model_path = model_config.path
+        is_submodel_override = False
+
+        # Does the config explicitly override the submodel?
+        if submodel_type is not None and hasattr(model_config, submodel_type):
+            submodel_path = getattr(model_config, submodel_type)
+            if submodel_path is not None and len(submodel_path) > 0:
+                model_path = getattr(model_config, submodel_type)
+                is_submodel_override = True
+
+        model_path = self.resolve_model_path(model_path)
+        return model_path, is_submodel_override
+
+    def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
+        """Get a model's config object."""
+        model_key = self.create_key(model_name, base_model, model_type)
+        try:
+            model_config = self.models[model_key]
+        except KeyError:
+            raise ModelNotFoundException(f"Model not found - {model_key}")
+        return model_config
+
+    def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
+        """Get the concrete implementation class for a specific model type."""
+        model_class = MODEL_CLASSES[base_model][model_type]
+        return model_class
+
+    def _instantiate(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> ModelBase:
+        """Make a new instance of this model, without loading it."""
+        model_config = self._get_model_config(base_model, model_name, model_type)
+        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
+        # FIXME: do non-overriden submodels get the right class?
+        constructor = self._get_implementation(base_model, model_type)
+        instance = constructor(model_path, base_model, model_type)
+        return instance
+
     def model_info(
         self,
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
-    ) -> dict:
+    ) -> Union[dict, None]:
         """
         Given a model name returns the OmegaConf (dict-like) object describing it.
         """
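A toy illustration of the vae/movq submodel override that `_get_model_path` now encapsulates; `ToyConfig` is invented for this sketch (the real `ModelConfigBase` is a pydantic model):

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple


@dataclass
class ToyConfig:
    path: str
    vae: Optional[str] = None  # per-model override for the VAE submodel


def toy_get_model_path(cfg: ToyConfig, submodel: Optional[str], root: Path) -> Tuple[Path, bool]:
    model_path, is_override = cfg.path, False
    override = getattr(cfg, submodel, None) if submodel else None
    if override:
        model_path, is_override = override, True  # load the override in place of the submodel
    return root / model_path, is_override


print(toy_get_model_path(ToyConfig("sd-1/main/foo", vae="sd-1/vae/better-vae"), "vae", Path("/models")))
# (PosixPath('/models/sd-1/vae/better-vae'), True)
```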
@@ -540,13 +589,16 @@ class ModelManager(object):
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
-    ) -> dict:
+    ) -> Union[dict, None]:
         """
         Returns a dict describing one installed model, using
         the combined format of the list_models() method.
         """
         models = self.list_models(base_model, model_type, model_name)
-        return models[0] if models else None
+        if len(models) >= 1:
+            return models[0]
+        else:
+            return None

     def list_models(
         self,

@@ -560,7 +612,7 @@ class ModelManager(object):

         model_keys = (
             [self.create_key(model_name, base_model, model_type)]
-            if model_name
+            if model_name and base_model and model_type
             else sorted(self.models, key=str.casefold)
         )
         models = []

@@ -596,7 +648,7 @@ class ModelManager(object):
         Print a table of models and their descriptions. This needs to be redone
         """
         # TODO: redo
-        for model_type, model_dict in self.list_models().items():
+        for model_dict in self.list_models():
             for model_name, model_info in model_dict.items():
                 line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
                 print(line)
@@ -658,7 +710,7 @@ class ModelManager(object):
         if path := model_attributes.get("path"):
             model_attributes["path"] = str(self.relative_model_path(Path(path)))

-        model_class = MODEL_CLASSES[base_model][model_type]
+        model_class = self._get_implementation(base_model, model_type)
         model_config = model_class.create_config(**model_attributes)
         model_key = self.create_key(model_name, base_model, model_type)

@@ -670,7 +722,7 @@ class ModelManager(object):
             # TODO: if path changed and old_model.path inside models folder should we delete this too?

             # remove conversion cache as config changed
-            old_model_path = self.app_config.root_path / old_model.path
+            old_model_path = self.resolve_model_path(old_model.path)
             old_model_cache = self._get_model_cache_path(old_model_path)
             if old_model_cache.exists():
                 if old_model_cache.is_dir():

@@ -699,8 +751,8 @@ class ModelManager(object):
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
-        new_name: str = None,
-        new_base: BaseModelType = None,
+        new_name: Optional[str] = None,
+        new_base: Optional[BaseModelType] = None,
     ):
         """
         Rename or rebase a model.

@@ -753,7 +805,7 @@ class ModelManager(object):
         self,
         model_name: str,
         base_model: BaseModelType,
-        model_type: Union[ModelType.Main, ModelType.Vae],
+        model_type: Literal[ModelType.Main, ModelType.Vae],
         dest_directory: Optional[Path] = None,
     ) -> AddModelResult:
         """

@@ -767,6 +819,10 @@ class ModelManager(object):
         This will raise a ValueError unless the model is a checkpoint.
         """
         info = self.model_info(model_name, base_model, model_type)
+
+        if info is None:
+            raise FileNotFoundError(f"model not found: {model_name}")
+
         if info["model_format"] != "checkpoint":
             raise ValueError(f"not a checkpoint format model: {model_name}")
@@ -780,7 +836,7 @@ class ModelManager(object):
             model_type,
             **submodel,
         )
-        checkpoint_path = self.app_config.root_path / info["path"]
+        checkpoint_path = self.resolve_model_path(info["path"])
         old_diffusers_path = self.resolve_model_path(model.location)
         new_diffusers_path = (
             dest_directory or self.app_config.models_path / base_model.value / model_type.value

@@ -836,7 +892,7 @@ class ModelManager(object):

         return search_folder, found_models

-    def commit(self, conf_file: Path = None) -> None:
+    def commit(self, conf_file: Optional[Path] = None) -> None:
         """
         Write current configuration out to the indicated file.
         """

@@ -845,7 +901,7 @@ class ModelManager(object):

         for model_key, model_config in self.models.items():
             model_name, base_model, model_type = self.parse_key(model_key)
-            model_class = MODEL_CLASSES[base_model][model_type]
+            model_class = self._get_implementation(base_model, model_type)
             if model_class.save_to_config:
                 # TODO: or exclude_unset better fits here?
                 data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})

@@ -903,7 +959,7 @@ class ModelManager(object):

         model_path = self.resolve_model_path(model_config.path).absolute()
         if not model_path.exists():
-            model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
+            model_class = self._get_implementation(cur_base_model, cur_model_type)
             if model_class.save_to_config:
                 model_config.error = ModelError.NotFound
                 self.models.pop(model_key, None)

@@ -919,7 +975,7 @@ class ModelManager(object):
         for cur_model_type in ModelType:
             if model_type is not None and cur_model_type != model_type:
                 continue
-            model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
+            model_class = self._get_implementation(cur_base_model, cur_model_type)
             models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))

             if not models_dir.exists():

@@ -935,7 +991,9 @@ class ModelManager(object):
                         raise DuplicateModelException(f"Model with key {model_key} added twice")

                     model_path = self.relative_model_path(model_path)
-                    model_config: ModelConfigBase = model_class.probe_config(str(model_path))
+                    model_config: ModelConfigBase = model_class.probe_config(
+                        str(model_path), model_base=cur_base_model
+                    )
                     self.models[model_key] = model_config
                     new_models_found = True
                 except DuplicateModelException as e:

@@ -983,7 +1041,7 @@ class ModelManager(object):
         # LS: hacky
         # Patch in the SD VAE from core so that it is available for use by the UI
         try:
-            self.heuristic_import({self.resolve_model_path("core/convert/sd-vae-ft-mse")})
+            self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
         except:
             pass

@@ -992,7 +1050,7 @@ class ModelManager(object):
             model_manager=self,
             prediction_type_helper=ask_user_for_prediction_type,
         )
-        known_paths = {config.root_path / x["path"] for x in self.list_models()}
+        known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
         directories = {
             config.root_path / x
             for x in [

@@ -1011,7 +1069,7 @@ class ModelManager(object):
     def heuristic_import(
         self,
         items_to_import: Set[str],
-        prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
+        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
     ) -> Dict[str, AddModelResult]:
         """Import a list of paths, repo_ids or URLs. Returns the set of
         successfully imported items.
@@ -33,7 +33,7 @@ class ModelMerger(object):
         self,
         model_paths: List[Path],
         alpha: float = 0.5,
-        interp: MergeInterpolationMethod = None,
+        interp: Optional[MergeInterpolationMethod] = None,
         force: bool = False,
         **kwargs,
     ) -> DiffusionPipeline:

@@ -73,7 +73,7 @@ class ModelMerger(object):
         base_model: Union[BaseModelType, str],
         merged_model_name: str,
         alpha: float = 0.5,
-        interp: MergeInterpolationMethod = None,
+        interp: Optional[MergeInterpolationMethod] = None,
         force: bool = False,
         merge_dest_directory: Optional[Path] = None,
         **kwargs,

@@ -122,7 +122,7 @@ class ModelMerger(object):
         dump_path.mkdir(parents=True, exist_ok=True)
         dump_path = dump_path / merged_model_name

-        merged_pipe.save_pretrained(dump_path, safe_serialization=1)
+        merged_pipe.save_pretrained(dump_path, safe_serialization=True)
         attributes = dict(
             path=str(dump_path),
             description=f"Merge of models {', '.join(model_names)}",
@@ -17,6 +17,7 @@ from .models import (
|
||||
SilenceWarnings,
|
||||
InvalidModelException,
|
||||
)
|
||||
from .util import lora_token_vector_length
|
||||
from .models.base import read_checkpoint_meta
|
||||
|
||||
|
||||
@@ -315,21 +316,16 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
||||
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
|
||||
lora_token_vector_length = (
|
||||
checkpoint[key1].shape[1]
|
||||
if key1 in checkpoint
|
||||
else checkpoint[key2].shape[0]
|
||||
if key2 in checkpoint
|
||||
else 768
|
||||
)
|
||||
if lora_token_vector_length == 768:
|
||||
token_vector_length = lora_token_vector_length(checkpoint)
|
||||
|
||||
if token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif lora_token_vector_length == 1024:
|
||||
elif token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
return None
|
||||
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
|
||||
|
||||
|
||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
|
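The probe's decision table, restated as a sketch (the 768/1024/2048 thresholds come from the hunk above; the string labels are illustrative):

```python
BASE_BY_TOKEN_LENGTH = {768: "sd-1", 1024: "sd-2", 2048: "sdxl"}


def base_from_token_length(n: int) -> str:
    try:
        return BASE_BY_TOKEN_LENGTH[n]
    except KeyError:
        raise ValueError(f"Unknown LoRA token vector length: {n}")


print(base_from_token_length(2048))  # sdxl
```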
@@ -292,8 +292,9 @@ class DiffusersModel(ModelBase):
                     )
                     break
                 except Exception as e:
-                    # print("====ERR LOAD====")
-                    # print(f"{variant}: {e}")
+                    if not str(e).startswith("Error no file"):
+                        print("====ERR LOAD====")
+                        print(f"{variant}: {e}")
                     pass
             else:
                 raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
@@ -1,7 +1,9 @@
 import os
 import torch
 from enum import Enum
-from typing import Optional, Union, Literal
+from typing import Optional, Dict, Union, Literal, Any
 from pathlib import Path
+from safetensors.torch import load_file
 from .base import (
     ModelBase,
     ModelConfigBase,

@@ -13,9 +15,6 @@ from .base import (
     ModelNotFoundException,
 )

-# TODO: naming
-from ..lora import LoRAModel as LoRAModelRaw
-

 class LoRAModelFormat(str, Enum):
     LyCORIS = "lycoris"

@@ -50,6 +49,7 @@ class LoRAModel(ModelBase):
             model = LoRAModelRaw.from_checkpoint(
                 file_path=self.model_path,
                 dtype=torch_dtype,
+                base_model=self.base_model,
             )

             self.model_size = model.calc_size()
@@ -87,3 +87,591 @@ class LoRAModel(ModelBase):
             raise NotImplementedError("Diffusers lora not supported")
         else:
             return model_path
+
+
+class LoRALayerBase:
+    # rank: Optional[int]
+    # alpha: Optional[float]
+    # bias: Optional[torch.Tensor]
+    # layer_key: str
+
+    # @property
+    # def scale(self):
+    #     return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        if "alpha" in values:
+            self.alpha = values["alpha"].item()
+        else:
+            self.alpha = None
+
+        if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
+            self.bias = torch.sparse_coo_tensor(
+                values["bias_indices"],
+                values["bias_values"],
+                tuple(values["bias_size"]),
+            )
+
+        else:
+            self.bias = None
+
+        self.rank = None  # set in layer implementation
+        self.layer_key = layer_key
+
+    def get_weight(self, orig_weight: torch.Tensor):
+        raise NotImplementedError()
+
+    def calc_size(self) -> int:
+        model_size = 0
+        for val in [self.bias]:
+            if val is not None:
+                model_size += val.nelement() * val.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        if self.bias is not None:
+            self.bias = self.bias.to(device=device, dtype=dtype)
+
+
+# TODO: find and debug lora/locon with bias
+class LoRALayer(LoRALayerBase):
+    # up: torch.Tensor
+    # mid: Optional[torch.Tensor]
+    # down: torch.Tensor
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        self.up = values["lora_up.weight"]
+        self.down = values["lora_down.weight"]
+        if "lora_mid.weight" in values:
+            self.mid = values["lora_mid.weight"]
+        else:
+            self.mid = None
+
+        self.rank = self.down.shape[0]
+
+    def get_weight(self, orig_weight: torch.Tensor):
+        if self.mid is not None:
+            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
+            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
+            weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
+        else:
+            weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
+
+        return weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        for val in [self.up, self.mid, self.down]:
+            if val is not None:
+                model_size += val.nelement() * val.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        self.up = self.up.to(device=device, dtype=dtype)
+        self.down = self.down.to(device=device, dtype=dtype)
+
+        if self.mid is not None:
+            self.mid = self.mid.to(device=device, dtype=dtype)
+
+
+class LoHALayer(LoRALayerBase):
+    # w1_a: torch.Tensor
+    # w1_b: torch.Tensor
+    # w2_a: torch.Tensor
+    # w2_b: torch.Tensor
+    # t1: Optional[torch.Tensor] = None
+    # t2: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        self.w1_a = values["hada_w1_a"]
+        self.w1_b = values["hada_w1_b"]
+        self.w2_a = values["hada_w2_a"]
+        self.w2_b = values["hada_w2_b"]
+
+        if "hada_t1" in values:
+            self.t1 = values["hada_t1"]
+        else:
+            self.t1 = None
+
+        if "hada_t2" in values:
+            self.t2 = values["hada_t2"]
+        else:
+            self.t2 = None
+
+        self.rank = self.w1_b.shape[0]
+
+    def get_weight(self, orig_weight: torch.Tensor):
+        if self.t1 is None:
+            weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
+
+        else:
+            rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
+            rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
+            weight = rebuild1 * rebuild2
+
+        return weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
+            if val is not None:
+                model_size += val.nelement() * val.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+        if self.t1 is not None:
+            self.t1 = self.t1.to(device=device, dtype=dtype)
+
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+        if self.t2 is not None:
+            self.t2 = self.t2.to(device=device, dtype=dtype)
+
+
+class LoKRLayer(LoRALayerBase):
+    # w1: Optional[torch.Tensor] = None
+    # w1_a: Optional[torch.Tensor] = None
+    # w1_b: Optional[torch.Tensor] = None
+    # w2: Optional[torch.Tensor] = None
+    # w2_a: Optional[torch.Tensor] = None
+    # w2_b: Optional[torch.Tensor] = None
+    # t2: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        if "lokr_w1" in values:
+            self.w1 = values["lokr_w1"]
+            self.w1_a = None
+            self.w1_b = None
+        else:
+            self.w1 = None
+            self.w1_a = values["lokr_w1_a"]
+            self.w1_b = values["lokr_w1_b"]
+
+        if "lokr_w2" in values:
+            self.w2 = values["lokr_w2"]
+            self.w2_a = None
+            self.w2_b = None
+        else:
+            self.w2 = None
+            self.w2_a = values["lokr_w2_a"]
+            self.w2_b = values["lokr_w2_b"]
+
+        if "lokr_t2" in values:
+            self.t2 = values["lokr_t2"]
+        else:
+            self.t2 = None
+
+        if "lokr_w1_b" in values:
+            self.rank = values["lokr_w1_b"].shape[0]
+        elif "lokr_w2_b" in values:
+            self.rank = values["lokr_w2_b"].shape[0]
+        else:
+            self.rank = None  # unscaled
+
+    def get_weight(self, orig_weight: torch.Tensor):
+        w1 = self.w1
+        if w1 is None:
+            w1 = self.w1_a @ self.w1_b
+
+        w2 = self.w2
+        if w2 is None:
+            if self.t2 is None:
+                w2 = self.w2_a @ self.w2_b
+            else:
+                w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
+
+        if len(w2.shape) == 4:
+            w1 = w1.unsqueeze(2).unsqueeze(2)
+        w2 = w2.contiguous()
+        weight = torch.kron(w1, w2)
+
+        return weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
+            if val is not None:
+                model_size += val.nelement() * val.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        if self.w1 is not None:
+            self.w1 = self.w1.to(device=device, dtype=dtype)
+        else:
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+
+        if self.w2 is not None:
+            self.w2 = self.w2.to(device=device, dtype=dtype)
+        else:
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+
+        if self.t2 is not None:
+            self.t2 = self.t2.to(device=device, dtype=dtype)
+
+
+class FullLayer(LoRALayerBase):
+    # weight: torch.Tensor
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        self.weight = values["diff"]
+
+        if len(values.keys()) > 1:
+            _keys = list(values.keys())
+            _keys.remove("diff")
+            raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
+
+        self.rank = None  # unscaled
+
+    def get_weight(self, orig_weight: torch.Tensor):
+        return self.weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        model_size += self.weight.nelement() * self.weight.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        self.weight = self.weight.to(device=device, dtype=dtype)
+
+
+class IA3Layer(LoRALayerBase):
+    # weight: torch.Tensor
+    # on_input: torch.Tensor
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        self.weight = values["weight"]
+        self.on_input = values["on_input"]
+
+        self.rank = None  # unscaled
+
+    def get_weight(self, orig_weight: torch.Tensor):
+        weight = self.weight
+        if not self.on_input:
+            weight = weight.reshape(-1, 1)
+        return orig_weight * weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        model_size += self.weight.nelement() * self.weight.element_size()
+        model_size += self.on_input.nelement() * self.on_input.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.on_input = self.on_input.to(device=device, dtype=dtype)
+
+
+# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
+class LoRAModelRaw:  # (torch.nn.Module):
+    _name: str
+    layers: Dict[str, LoRALayer]
+    _device: torch.device
+    _dtype: torch.dtype
+
+    def __init__(
+        self,
+        name: str,
+        layers: Dict[str, LoRALayer],
+        device: torch.device,
+        dtype: torch.dtype,
+    ):
+        self._name = name
+        self._device = device or torch.cpu
+        self._dtype = dtype or torch.float32
+        self.layers = layers
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        # TODO: try revert if exception?
+        for key, layer in self.layers.items():
+            layer.to(device=device, dtype=dtype)
+        self._device = device
+        self._dtype = dtype
+
+    def calc_size(self) -> int:
+        model_size = 0
+        for _, layer in self.layers.items():
+            model_size += layer.calc_size()
+        return model_size
+
+    @classmethod
+    def _convert_sdxl_compvis_keys(cls, state_dict):
+        new_state_dict = dict()
+        for full_key, value in state_dict.items():
+            if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
+                continue  # clip same
+
+            if not full_key.startswith("lora_unet_"):
+                raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
+            src_key = full_key.replace("lora_unet_", "")
+            try:
+                dst_key = None
+                while "_" in src_key:
+                    if src_key in SDXL_UNET_COMPVIS_MAP:
+                        dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
+                        break
+                    src_key = "_".join(src_key.split("_")[:-1])
+
+                if dst_key is None:
+                    raise Exception(f"Unknown sdxl lora key - {full_key}")
+                new_key = full_key.replace(src_key, dst_key)
+            except:
+                print(SDXL_UNET_COMPVIS_MAP)
+                raise
+            new_state_dict[new_key] = value
+        return new_state_dict
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        file_path: Union[str, Path],
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        base_model: Optional[BaseModelType] = None,
+    ):
+        device = device or torch.device("cpu")
+        dtype = dtype or torch.float32
+
+        if isinstance(file_path, str):
+            file_path = Path(file_path)
+
+        model = cls(
+            device=device,
+            dtype=dtype,
+            name=file_path.stem,  # TODO:
+            layers=dict(),
+        )
+
+        if file_path.suffix == ".safetensors":
+            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
+        else:
+            state_dict = torch.load(file_path, map_location="cpu")
+
+        state_dict = cls._group_state(state_dict)
+
+        if base_model == BaseModelType.StableDiffusionXL:
+            state_dict = cls._convert_sdxl_compvis_keys(state_dict)
+
+        for layer_key, values in state_dict.items():
+            # lora and locon
+            if "lora_down.weight" in values:
+                layer = LoRALayer(layer_key, values)
+
+            # loha
+            elif "hada_w1_b" in values:
+                layer = LoHALayer(layer_key, values)
+
+            # lokr
+            elif "lokr_w1_b" in values or "lokr_w1" in values:
+                layer = LoKRLayer(layer_key, values)
+
+            # diff
+            elif "diff" in values:
+                layer = FullLayer(layer_key, values)
+
+            # ia3
+            elif "weight" in values and "on_input" in values:
+                layer = IA3Layer(layer_key, values)
+
+            else:
+                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
+                raise Exception("Unknown lora format!")
+
+            # lower memory consumption by removing already parsed layer values
+            state_dict[layer_key].clear()
+
+            layer.to(device=device, dtype=dtype)
+            model.layers[layer_key] = layer
+
+        return model
+
+    @staticmethod
+    def _group_state(state_dict: dict):
+        state_dict_groupped = dict()
+
+        for key, value in state_dict.items():
+            stem, leaf = key.split(".", 1)
+            if stem not in state_dict_groupped:
+                state_dict_groupped[stem] = dict()
+            state_dict_groupped[stem][leaf] = value
+
+        return state_dict_groupped
+
+
+# code from
+# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
+def make_sdxl_unet_conversion_map():
+    unet_conversion_map_layer = []
+
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+            # if i > 0: commentout for sdxl
+            # no attention layers in up_blocks.0
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+
+    unet_conversion_map = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
+        sd_time_embed_prefix = f"time_embed.{j*2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
+        sd_label_embed_prefix = f"label_emb.0.{j*2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+
+    return unet_conversion_map
+
+
+SDXL_UNET_COMPVIS_MAP = {
+    f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
+    for sd, hf in make_sdxl_unet_conversion_map()
+}
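A hypothetical usage of the loader and key map added above; the checkpoint path is illustrative. The lookup result for the sample key follows directly from `make_sdxl_unet_conversion_map` as written.

```python
import torch

# Load a LyCORIS/LoRA checkpoint into the raw in-memory representation.
lora = LoRAModelRaw.from_checkpoint(
    "/path/to/some_lora.safetensors",  # illustrative path
    device=torch.device("cpu"),
    dtype=torch.float32,
)
print(f"{lora.name}: {len(lora.layers)} layers, ~{lora.calc_size() / 2**20:.1f} MiB")

# The compvis->diffusers key map, applied to one flattened key:
# "input_blocks.1.0.in_layers.2." maps to "down_blocks.0.resnets.0.conv1."
print(SDXL_UNET_COMPVIS_MAP["input_blocks_1_0_in_layers_2"])  # down_blocks_0_resnets_0_conv1
```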
@@ -80,8 +80,10 @@ class StableDiffusionXLModel(DiffusersModel):
             raise Exception("Unkown stable diffusion 2.* model format")

         if ckpt_config_path is None:
-            # TO DO: implement picking
-            pass
+            # avoid circular import
+            from .stable_diffusion import _select_ckpt_config
+
+            ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)

         return cls.create_config(
             path=path,
@@ -4,6 +4,7 @@ from enum import Enum
 from pydantic import Field
 from pathlib import Path
 from typing import Literal, Optional, Union
+from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
 from .base import (
     ModelConfigBase,
     BaseModelType,

@@ -263,6 +264,8 @@ def _convert_ckpt_and_cache(
     weights = app_config.models_path / model_config.path
     config_file = app_config.root_path / model_config.config
     output_path = Path(output_path)
+    variant = model_config.variant
+    pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline

     # return cached version if it exists
     if output_path.exists():

@@ -289,6 +292,7 @@ def _convert_ckpt_and_cache(
         original_config_file=config_file,
         extract_ema=True,
         scan_needed=True,
+        pipeline_class=pipeline_class,
         from_safetensors=weights.suffix == ".safetensors",
         precision=torch_dtype(choose_torch_device()),
         **kwargs,
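The two added lines choose the conversion pipeline by checkpoint variant; restated as a sketch:

```python
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline


def pick_pipeline_class(variant):
    # "inpaint" checkpoints need the inpainting pipeline; everything else converts as txt2img
    return StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline
```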
||||
@@ -1,9 +1,14 @@
 import os
-import torch
-import safetensors
 from enum import Enum
 from pathlib import Path
-from typing import Optional, Union, Literal
+from typing import Optional
+
+import safetensors
+import torch
+from diffusers.utils import is_safetensors_available
+from omegaconf import OmegaConf
+
+from invokeai.app.services.config import InvokeAIAppConfig
 from .base import (
     ModelBase,
     ModelConfigBase,
@@ -18,9 +23,6 @@ from .base import (
     InvalidModelException,
     ModelNotFoundException,
 )
-from invokeai.app.services.config import InvokeAIAppConfig
-from diffusers.utils import is_safetensors_available
-from omegaconf import OmegaConf


 class VaeModelFormat(str, Enum):
@@ -80,7 +82,7 @@ class VaeModel(ModelBase):
     @classmethod
     def detect_format(cls, path: str):
         if not os.path.exists(path):
-            raise ModelNotFoundException()
+            raise ModelNotFoundException(f"Does not exist as local file: {path}")

         if os.path.isdir(path):
             if os.path.exists(os.path.join(path, "config.json")):
invokeai/backend/model_management/util.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""


def lora_token_vector_length(checkpoint: dict) -> int:
    """
    Given a checkpoint in memory, return the lora token vector length

    :param checkpoint: The checkpoint
    """

    def _get_shape_1(key, tensor, checkpoint):
        lora_token_vector_length = None

        if "." not in key:
            return lora_token_vector_length  # wrong key format
        model_key, lora_key = key.split(".", 1)

        # check lora/locon
        if lora_key == "lora_down.weight":
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2, as they are used only in 4d shapes)
        elif lora_key in ["hada_w1_b", "hada_w2_b"]:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2, as it is used only in 4d shapes)
        elif "lokr_" in lora_key:
            if model_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1"]
            elif model_key + ".lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if model_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
            elif model_key + ".lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif lora_key == "diff":
            lora_token_vector_length = tensor.shape[1]

        # ia3 can be detected only by shape[0] in text encoder
        elif lora_key == "weight" and "lora_unet_" not in model_key:
            lora_token_vector_length = tensor.shape[0]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_te") and "_self_attn_" in key:
            tmp_length = _get_shape_1(key, tensor, checkpoint)
            if key.startswith("lora_te_"):
                lora_token_vector_length = tmp_length
            elif key.startswith("lora_te1_"):
                lora_te1_length = tmp_length
            elif key.startswith("lora_te2_"):
                lora_te2_length = tmp_length

        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length
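A hedged usage sketch (not part of the diff): the file path is hypothetical, and the 768/1024/2048 thresholds are assumptions based on the CLIP text-encoder widths of SD 1.x, SD 2.x, and SDXL respectively.

from safetensors.torch import load_file

from invokeai.backend.model_management.util import lora_token_vector_length

checkpoint = load_file("my_lora.safetensors")  # hypothetical path
length = lora_token_vector_length(checkpoint)
if length == 768:
    print("Looks like an SD-1.x LoRA")
elif length == 1024:
    print("Looks like an SD-2.x LoRA")
elif length == 2048:
    print("Looks like an SDXL LoRA")
else:
    print(f"Unrecognized token vector length: {length}")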
@@ -4,25 +4,21 @@ import dataclasses
 import inspect
 import math
 import secrets
 from collections.abc import Sequence
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
-from pydantic import Field

+import einops
 import PIL.Image
 import numpy as np
-from accelerate.utils import set_seed
-import einops
 import psutil
 import torch
 import torchvision.transforms as T
+from accelerate.utils import set_seed
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.controlnet import ControlNetModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
     StableDiffusionPipeline,
 )
-
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
     StableDiffusionImg2ImgPipeline,
 )
@@ -31,21 +27,20 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
 )
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
 from diffusers.utils import PIL_INTERPOLATION
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
+from pydantic import Field
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from typing_extensions import ParamSpec

 from invokeai.app.services.config import InvokeAIAppConfig
-from ..util import CPU_DEVICE, normalize_device
 from .diffusion import (
     AttentionMapSaver,
     InvokeAIDiffuserComponent,
     PostprocessingSettings,
 )
-from .offloading import FullyLoadedModelGroup, ModelGroup
+from ..util import normalize_device


 @dataclass
@@ -289,8 +284,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _model_group: ModelGroup
-
     ID_LENGTH = 8

     def __init__(
@@ -303,9 +296,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         safety_checker: Optional[StableDiffusionSafetyChecker],
         feature_extractor: Optional[CLIPFeatureExtractor],
         requires_safety_checker: bool = False,
         precision: str = "float32",
         control_model: ControlNetModel = None,
-        execution_device: Optional[torch.device] = None,
     ):
         super().__init__(
             vae,
@@ -330,9 +321,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             # control_model=control_model,
         )
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
-
-        self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
-        self._model_group.install(*self._submodels)
         self.control_model = control_model

     def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
@@ -368,72 +356,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         else:
             self.disable_attention_slicing()

-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
-        # overridden method; types match the superclass.
-        if torch_device is None:
-            return self
-        self._model_group.set_device(torch.device(torch_device))
-        self._model_group.ready()
-
-    @property
-    def device(self) -> torch.device:
-        return self._model_group.execution_device
-
-    @property
-    def _submodels(self) -> Sequence[torch.nn.Module]:
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        submodels = []
-        for name in module_names.keys():
-            if hasattr(self, name):
-                value = getattr(self, name)
-            else:
-                value = getattr(self.config, name)
-            if isinstance(value, torch.nn.Module):
-                submodels.append(value)
-        return submodels
-
-    def image_from_embeddings(
-        self,
-        latents: torch.Tensor,
-        num_inference_steps: int,
-        conditioning_data: ConditioningData,
-        *,
-        noise: torch.Tensor,
-        callback: Callable[[PipelineIntermediateState], None] = None,
-        run_id=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        r"""
-        Function invoked when calling the pipeline for generation.
-
-        :param conditioning_data:
-        :param latents: Pre-generated un-noised latents, to be used as inputs for
-            image generation. Can be used to tweak the same generation with different prompts.
-        :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
-            image at the expense of slower inference.
-        :param noise: Noise to add to the latents, sampled from a Gaussian distribution.
-        :param callback:
-        :param run_id:
-        """
-        result_latents, result_attention_map_saver = self.latents_from_embeddings(
-            latents,
-            num_inference_steps,
-            conditioning_data,
-            noise=noise,
-            run_id=run_id,
-            callback=callback,
-        )
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        with torch.inference_mode():
-            image = self.decode_latents(result_latents)
-            output = InvokeAIStableDiffusionPipelineOutput(
-                images=image,
-                nsfw_content_detected=[],
-                attention_map_saver=result_attention_map_saver,
-            )
-            return self.check_for_safety(output, dtype=conditioning_data.dtype)
-
     def latents_from_embeddings(
         self,
         latents: torch.Tensor,
@@ -450,7 +372,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device("cpu")
         else:
-            scheduler_device = self._model_group.device_for(self.unet)
+            scheduler_device = self.unet.device

         if timesteps is None:
             self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
@@ -504,7 +426,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             (batch_size,),
             timesteps[0],
             dtype=timesteps.dtype,
-            device=self._model_group.device_for(self.unet),
+            device=self.unet.device,
         )
         latents = self.scheduler.add_noise(latents, noise, batched_t)
@@ -700,79 +622,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             **kwargs,
         ).sample

-    def img2img_from_embeddings(
-        self,
-        init_image: Union[torch.FloatTensor, PIL.Image.Image],
-        strength: float,
-        num_inference_steps: int,
-        conditioning_data: ConditioningData,
-        *,
-        callback: Callable[[PipelineIntermediateState], None] = None,
-        run_id=None,
-        noise_func=None,
-        seed=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
-
-        if init_image.dim() == 3:
-            init_image = einops.rearrange(init_image, "c h w -> 1 c h w")
-
-        # 6. Prepare latent variables
-        initial_latents = self.non_noised_latents_from_image(
-            init_image,
-            device=self._model_group.device_for(self.unet),
-            dtype=self.unet.dtype,
-        )
-        if seed is not None:
-            set_seed(seed)
-        noise = noise_func(initial_latents)
-
-        return self.img2img_from_latents_and_embeddings(
-            initial_latents,
-            num_inference_steps,
-            conditioning_data,
-            strength,
-            noise,
-            run_id,
-            callback,
-        )
-
-    def img2img_from_latents_and_embeddings(
-        self,
-        initial_latents,
-        num_inference_steps,
-        conditioning_data: ConditioningData,
-        strength,
-        noise: torch.Tensor,
-        run_id=None,
-        callback=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
-        result_latents, result_attention_maps = self.latents_from_embeddings(
-            latents=initial_latents
-            if strength < 1.0
-            else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
-            num_inference_steps=num_inference_steps,
-            conditioning_data=conditioning_data,
-            timesteps=timesteps,
-            noise=noise,
-            run_id=run_id,
-            callback=callback,
-        )
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        with torch.inference_mode():
-            image = self.decode_latents(result_latents)
-            output = InvokeAIStableDiffusionPipelineOutput(
-                images=image,
-                nsfw_content_detected=[],
-                attention_map_saver=result_attention_maps,
-            )
-            return self.check_for_safety(output, dtype=conditioning_data.dtype)
-
     def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         assert img2img_pipeline.scheduler is self.scheduler
@@ -780,7 +629,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device("cpu")
         else:
-            scheduler_device = self._model_group.device_for(self.unet)
+            scheduler_device = self.unet.device

         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
         timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
@@ -806,7 +655,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise_func=None,
         seed=None,
     ) -> InvokeAIStableDiffusionPipelineOutput:
-        device = self._model_group.device_for(self.unet)
+        device = self.unet.device
         latents_dtype = self.unet.dtype

         if isinstance(init_image, PIL.Image.Image):
@@ -877,42 +726,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 nsfw_content_detected=[],
                 attention_map_saver=result_attention_maps,
             )
-            return self.check_for_safety(output, dtype=conditioning_data.dtype)
+            return output

     def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
-            self._model_group.load(self.vae)
             init_latent_dist = self.vae.encode(init_image).latent_dist
             init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!

         init_latents = 0.18215 * init_latents
         return init_latents

-    def check_for_safety(self, output, dtype):
-        with torch.inference_mode():
-            screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
-        screened_attention_map_saver = None
-        if has_nsfw_concept is None or not has_nsfw_concept:
-            screened_attention_map_saver = output.attention_map_saver
-        return InvokeAIStableDiffusionPipelineOutput(
-            screened_images,
-            has_nsfw_concept,
-            # block the attention maps if NSFW content is detected
-            attention_map_saver=screened_attention_map_saver,
-        )
-
-    def run_safety_checker(self, image, device=None, dtype=None):
-        # overriding to use the model group for device info instead of requiring the caller to know.
-        if self.safety_checker is not None:
-            device = self._model_group.device_for(self.safety_checker)
-        return super().run_safety_checker(image, device, dtype)
-
-    def decode_latents(self, latents):
-        # Explicit call to get the vae loaded, since `decode` isn't the forward method.
-        self._model_group.load(self.vae)
-        return super().decode_latents(latents)
-
     def debug_latents(self, latents, msg):
         from invokeai.backend.image_util import debug_image
@@ -78,10 +78,9 @@ class InvokeAIDiffuserComponent:
         self.cross_attention_control_context = None
         self.sequential_guidance = config.sequential_guidance

-    @classmethod
     @contextmanager
     def custom_attention_context(
-        cls,
+        self,
         unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
         extra_conditioning_info: Optional[ExtraConditioningInfo],
         step_count: int,
@@ -91,18 +90,19 @@ class InvokeAIDiffuserComponent:
         old_attn_processors = unet.attn_processors
         # Load lora conditions into the model
         if extra_conditioning_info.wants_cross_attention_control:
-            cross_attention_control_context = Context(
+            self.cross_attention_control_context = Context(
                 arguments=extra_conditioning_info.cross_attention_control_args,
                 step_count=step_count,
             )
             setup_cross_attention_control_attention_processors(
                 unet,
-                cross_attention_control_context,
+                self.cross_attention_control_context,
            )

         try:
             yield None
         finally:
+            self.cross_attention_control_context = None
             if old_attn_processors is not None:
                 unet.set_attn_processor(old_attn_processors)
             # TODO resuscitate attention map saving
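A hypothetical call site for the reworked context manager (now an instance method rather than a classmethod); conditioning_info and timesteps are assumed to exist in the caller:

with self.invokeai_diffuser.custom_attention_context(
    unet=self.unet,
    extra_conditioning_info=conditioning_info,  # assumed available in the caller
    step_count=len(timesteps),
):
    # run the denoising loop here; cross-attention-control processors are
    # installed for the duration and the old processors restored afterwards
    ...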
@@ -1,253 +0,0 @@ (invokeai/backend/stable_diffusion/offloading.py, file deleted)
from __future__ import annotations

import warnings
import weakref
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
from typing import Callable, Union

import torch
from accelerate.utils import send_to_device
from torch.utils.hooks import RemovableHandle

OFFLOAD_DEVICE = torch.device("cpu")


class _NoModel:
    """Symbol that indicates no model is loaded.

    (We can't weakref.ref(None), so this was my best idea at the time to come up with something
    type-checkable.)
    """

    def __bool__(self):
        return False

    def to(self, device: torch.device):
        pass

    def __repr__(self):
        return "<NO MODEL>"


NO_MODEL = _NoModel()


class ModelGroup(metaclass=ABCMeta):
    """
    A group of models.

    The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
    e.g. its text encoder, U-net, VAE, etc.

    Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
    :py:class:`torch.nn.Module` here.
    """

    def __init__(self, execution_device: torch.device):
        self.execution_device = execution_device

    @abstractmethod
    def install(self, *models: torch.nn.Module):
        """Add models to this group."""
        pass

    @abstractmethod
    def uninstall(self, models: torch.nn.Module):
        """Remove models from this group."""
        pass

    @abstractmethod
    def uninstall_all(self):
        """Remove all models from this group."""

    @abstractmethod
    def load(self, model: torch.nn.Module):
        """Load this model to the execution device."""
        pass

    @abstractmethod
    def offload_current(self):
        """Offload the current model(s) from the execution device."""
        pass

    @abstractmethod
    def ready(self):
        """Ready this group for use."""
        pass

    @abstractmethod
    def set_device(self, device: torch.device):
        """Change which device models from this group will execute on."""
        pass

    @abstractmethod
    def device_for(self, model) -> torch.device:
        """Get the device the given model will execute on.

        The model should already be a member of this group.
        """
        pass

    @abstractmethod
    def __contains__(self, model):
        """Check if the model is a member of this group."""
        pass

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} object at {id(self):x}: device={self.execution_device} >"


class LazilyLoadedModelGroup(ModelGroup):
    """
    Only one model from this group is loaded on the GPU at a time.

    Running the forward method of a model will displace the previously-loaded model,
    offloading it to CPU.

    If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
    you will need to explicitly load it with :py:method:`.load(model)`.

    This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
    to the appropriate execution device, as long as they are positional arguments and not keyword
    arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
    """

    _hooks: MutableMapping[torch.nn.Module, RemovableHandle]
    _current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]

    def __init__(self, execution_device: torch.device):
        super().__init__(execution_device)
        self._hooks = weakref.WeakKeyDictionary()
        self._current_model_ref = weakref.ref(NO_MODEL)

    def install(self, *models: torch.nn.Module):
        for model in models:
            self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)

    def uninstall(self, *models: torch.nn.Module):
        for model in models:
            hook = self._hooks.pop(model)
            hook.remove()
            if self.is_current_model(model):
                # no longer hooked by this object, so don't claim to manage it
                self.clear_current_model()

    def uninstall_all(self):
        self.uninstall(*self._hooks.keys())

    def _pre_hook(self, module: torch.nn.Module, forward_input):
        self.load(module)
        if len(forward_input) == 0:
            warnings.warn(
                f"Hook for {module.__class__.__name__} got no input. Inputs must be positional, not keywords.",
                stacklevel=3,
            )
        return send_to_device(forward_input, self.execution_device)

    def load(self, module):
        if not self.is_current_model(module):
            self.offload_current()
            self._load(module)

    def offload_current(self):
        module = self._current_model_ref()
        if module is not NO_MODEL:
            module.to(OFFLOAD_DEVICE)
        self.clear_current_model()

    def _load(self, module: torch.nn.Module) -> torch.nn.Module:
        assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
        module = module.to(self.execution_device)
        self.set_current_model(module)
        return module

    def is_current_model(self, model: torch.nn.Module) -> bool:
        """Is the given model the one currently loaded on the execution device?"""
        return self._current_model_ref() is model

    def is_empty(self):
        """Are none of this group's models loaded on the execution device?"""
        return self._current_model_ref() is NO_MODEL

    def set_current_model(self, value):
        self._current_model_ref = weakref.ref(value)

    def clear_current_model(self):
        self._current_model_ref = weakref.ref(NO_MODEL)

    def set_device(self, device: torch.device):
        if device == self.execution_device:
            return
        self.execution_device = device
        current = self._current_model_ref()
        if current is not NO_MODEL:
            current.to(device)

    def device_for(self, model):
        if model not in self:
            raise KeyError(f"This does not manage this model {type(model).__name__}", model)
        return self.execution_device  # this implementation only dispatches to one device

    def ready(self):
        pass  # always ready to load on-demand

    def __contains__(self, model):
        return model in self._hooks

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__name__} object at {id(self):x}: "
            f"current_model={type(self._current_model_ref()).__name__} >"
        )


class FullyLoadedModelGroup(ModelGroup):
    """
    A group of models without any implicit loading or unloading.

    :py:meth:`.ready` loads _all_ the models to the execution device at once.
    """

    _models: weakref.WeakSet

    def __init__(self, execution_device: torch.device):
        super().__init__(execution_device)
        self._models = weakref.WeakSet()

    def install(self, *models: torch.nn.Module):
        for model in models:
            self._models.add(model)
            model.to(self.execution_device)

    def uninstall(self, *models: torch.nn.Module):
        for model in models:
            self._models.remove(model)

    def uninstall_all(self):
        self.uninstall(*self._models)

    def load(self, model):
        model.to(self.execution_device)

    def offload_current(self):
        for model in self._models:
            model.to(OFFLOAD_DEVICE)

    def ready(self):
        for model in self._models:
            self.load(model)

    def set_device(self, device: torch.device):
        self.execution_device = device
        for model in self._models:
            if model.device != OFFLOAD_DEVICE:
                model.to(device)

    def device_for(self, model):
        if model not in self:
            raise KeyError(f"This does not manage this model {type(model).__name__}", model)
        return self.execution_device  # this implementation only dispatches to one device

    def __contains__(self, model):
        return model in self._models
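For context, an illustrative sketch (not from the diff) of how the now-removed lazy group was typically driven; the toy Linear modules and the CUDA device are assumptions:

import torch

group = LazilyLoadedModelGroup(torch.device("cuda:0"))  # assumes a CUDA device
encoder = torch.nn.Linear(8, 8)
decoder = torch.nn.Linear(8, 8)
group.install(encoder, decoder)

x = torch.randn(1, 8)
y = encoder(x)  # pre-hook loads `encoder` to cuda:0 and moves `x` there
z = decoder(y)  # offloads `encoder` to CPU, then loads `decoder`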
@@ -1,6 +1,8 @@
 from __future__ import annotations

 from contextlib import nullcontext
+from packaging import version
+import platform

 import torch
 from torch import autocast
@@ -30,7 +32,7 @@ def choose_precision(device: torch.device) -> str:
         device_name = torch.cuda.get_device_name(device)
         if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
             return "float16"
-    elif device.type == "mps":
+    elif device.type == "mps" and version.parse(platform.mac_ver()[0]) < version.parse("14.0.0"):
         return "float16"
     return "float32"
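A short usage sketch (not part of the diff; the module path is taken from this repo's layout). After this change, MPS only gets float16 on macOS versions below 14.0.0:

import torch

from invokeai.backend.util.devices import choose_precision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(choose_precision(device))  # "float16" on most CUDA cards, otherwise "float32"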
@@ -1,6 +1,3 @@
 """
 Initialization file for invokeai.frontend.config
 """
-from .invokeai_configure import main as invokeai_configure
-from .invokeai_update import main as invokeai_update
-from .model_install import main as invokeai_model_install
invokeai/frontend/install/import_images.py (new file, 795 lines)
@@ -0,0 +1,795 @@
# Copyright (c) 2023 - The InvokeAI Team
# Primary Author: David Lovell (github @f412design, discord @techjedi)
# co-author, minor tweaks - Lincoln Stein

# pylint: disable=line-too-long
# pylint: disable=broad-exception-caught
"""Script to import images into the new database system for 3.0.0"""

import os
import datetime
import shutil
import locale
import sqlite3
import json
import glob
import re
import uuid
import yaml
import PIL
import PIL.ImageOps
import PIL.PngImagePlugin

from pathlib import Path
from prompt_toolkit import prompt
from prompt_toolkit.shortcuts import message_dialog
from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.key_binding import KeyBindings

from invokeai.app.services.config import InvokeAIAppConfig

app_config = InvokeAIAppConfig.get_config()

bindings = KeyBindings()


@bindings.add("c-c")
def _(event):
    raise KeyboardInterrupt


# release notes
# "Use All" with size dimensions not selectable in the UI will not load dimensions


class Config:
    """Configuration loader."""

    def __init__(self):
        pass

    TIMESTAMP_STRING = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")

    INVOKE_DIRNAME = "invokeai"
    YAML_FILENAME = "invokeai.yaml"
    DATABASE_FILENAME = "invokeai.db"

    database_path = None
    database_backup_dir = None
    outputs_path = None
    thumbnail_path = None

    def find_and_load(self):
        """Find the yaml config file and load it."""
        root = app_config.root_path
        if not self.confirm_and_load(os.path.abspath(root)):
            print("\r\nSpecify custom database and outputs paths:")
            self.confirm_and_load_from_user()

        self.database_backup_dir = os.path.join(os.path.dirname(self.database_path), "backup")
        self.thumbnail_path = os.path.join(self.outputs_path, "thumbnails")

    def confirm_and_load(self, invoke_root):
        """Validate that a yaml path exists, confirm the user wants to use it, and load the config."""
        yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
        if os.path.exists(yaml_path):
            db_dir, outdir = self.load_paths_from_yaml(yaml_path)
            if os.path.isabs(db_dir):
                database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
            else:
                database_path = os.path.join(invoke_root, db_dir, self.DATABASE_FILENAME)

            if os.path.isabs(outdir):
                outputs_path = os.path.join(outdir, "images")
            else:
                outputs_path = os.path.join(invoke_root, outdir, "images")

            db_exists = os.path.exists(database_path)
            outdir_exists = os.path.exists(outputs_path)

            text = f"Found {self.YAML_FILENAME} file at {yaml_path}:"
            text += f"\n  Database : {database_path}"
            text += f"\n  Outputs  : {outputs_path}"
            text += "\n\nUse these paths for import (yes) or choose different ones (no) [Yn]: "

            if db_exists and outdir_exists:
                if (prompt(text).strip() or "Y").upper().startswith("Y"):
                    self.database_path = database_path
                    self.outputs_path = outputs_path
                    return True
                else:
                    return False
            else:
                print(" Invalid: One or more paths in this config did not exist and cannot be used.")

        else:
            message_dialog(
                title="Path not found",
                text=f"Auto-discovery of configuration failed! Could not find ({yaml_path}). Custom paths can be specified.",
            ).run()
        return False

    def confirm_and_load_from_user(self):
        default = ""
        while True:
            database_path = os.path.expanduser(
                prompt(
                    "Database: Specify absolute path to the database to import into: ",
                    completer=PathCompleter(
                        expanduser=True, file_filter=lambda x: Path(x).is_dir() or x.endswith((".db"))
                    ),
                    default=default,
                )
            )
            if database_path.endswith(".db") and os.path.isabs(database_path) and os.path.exists(database_path):
                break
            default = database_path + "/" if Path(database_path).is_dir() else database_path

        default = ""
        while True:
            outputs_path = os.path.expanduser(
                prompt(
                    "Outputs: Specify absolute path to outputs/images directory to import into: ",
                    completer=PathCompleter(expanduser=True, only_directories=True),
                    default=default,
                )
            )

            if outputs_path.endswith("images") and os.path.isabs(outputs_path) and os.path.exists(outputs_path):
                break
            default = outputs_path + "/" if Path(outputs_path).is_dir() else outputs_path

        self.database_path = database_path
        self.outputs_path = outputs_path

        return

    def load_paths_from_yaml(self, yaml_path):
        """Load an Invoke AI yaml file and get the database and outputs paths."""
        try:
            with open(yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
                yamlinfo = yaml.safe_load(file)
                db_dir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("db_dir", None)
                outdir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("outdir", None)
                return db_dir, outdir
        except Exception:
            print(f"Failed to load paths from yaml file! {yaml_path}!")
            return None, None


class ImportStats:
    """DTO for tracking work progress."""

    def __init__(self):
        pass

    time_start = datetime.datetime.utcnow()
    count_source_files = 0
    count_skipped_file_exists = 0
    count_skipped_db_exists = 0
    count_imported = 0
    count_imported_by_version = {}
    count_file_errors = 0

    @staticmethod
    def get_elapsed_time_string():
        """Get a friendly time string for the time elapsed since processing start."""
        time_now = datetime.datetime.utcnow()
        total_seconds = (time_now - ImportStats.time_start).total_seconds()
        hours = int((total_seconds) / 3600)
        minutes = int(((total_seconds) % 3600) / 60)
        seconds = total_seconds % 60
        out_str = f"{hours} hour(s) -" if hours > 0 else ""
        out_str += f"{minutes} minute(s) -" if minutes > 0 else ""
        out_str += f"{seconds:.2f} second(s)"
        return out_str


class InvokeAIMetadata:
    """DTO for core Invoke AI generation properties parsed from metadata."""

    def __init__(self):
        pass

    def __str__(self):
        formatted_str = f"{self.generation_mode}~{self.steps}~{self.cfg_scale}~{self.model_name}~{self.scheduler}~{self.seed}~{self.width}~{self.height}~{self.rand_device}~{self.strength}~{self.init_image}"
        formatted_str += f"\r\npositive_prompt: {self.positive_prompt}"
        formatted_str += f"\r\nnegative_prompt: {self.negative_prompt}"
        return formatted_str

    generation_mode = None
    steps = None
    cfg_scale = None
    model_name = None
    scheduler = None
    seed = None
    width = None
    height = None
    rand_device = None
    strength = None
    init_image = None
    positive_prompt = None
    negative_prompt = None
    imported_app_version = None

    def to_json(self):
        """Convert the active instance to json format."""
        prop_dict = {}
        prop_dict["generation_mode"] = self.generation_mode
        # don't render prompt nodes if neither is set, to avoid the ui thinking it can set them
        # if at least one exists, render them both, but use empty string instead of None if one of them is empty
        # this allows the field that is empty to actually be cleared by the UI instead of leaving the previous value
        if self.positive_prompt or self.negative_prompt:
            prop_dict["positive_prompt"] = "" if self.positive_prompt is None else self.positive_prompt
            prop_dict["negative_prompt"] = "" if self.negative_prompt is None else self.negative_prompt
        prop_dict["width"] = self.width
        prop_dict["height"] = self.height
        # only render seed if it has a value, to avoid the ui thinking it can set this and then erroring
        if self.seed:
            prop_dict["seed"] = self.seed
        prop_dict["rand_device"] = self.rand_device
        prop_dict["cfg_scale"] = self.cfg_scale
        prop_dict["steps"] = self.steps
        prop_dict["scheduler"] = self.scheduler
        prop_dict["clip_skip"] = 0
        prop_dict["model"] = {}
        prop_dict["model"]["model_name"] = self.model_name
        prop_dict["model"]["base_model"] = None
        prop_dict["controlnets"] = []
        prop_dict["loras"] = []
        prop_dict["vae"] = None
        prop_dict["strength"] = self.strength
        prop_dict["init_image"] = self.init_image
        prop_dict["positive_style_prompt"] = None
        prop_dict["negative_style_prompt"] = None
        prop_dict["refiner_model"] = None
        prop_dict["refiner_cfg_scale"] = None
        prop_dict["refiner_steps"] = None
        prop_dict["refiner_scheduler"] = None
        prop_dict["refiner_aesthetic_store"] = None
        prop_dict["refiner_start"] = None
        prop_dict["imported_app_version"] = self.imported_app_version

        return json.dumps(prop_dict)
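A hypothetical round trip (not part of the new file) showing how the DTO emits the JSON stored in the database and in the PNG "invokeai_metadata" tag; the field values are illustrative:

meta = InvokeAIMetadata()
meta.generation_mode = "txt2img"
meta.positive_prompt = "a red rose"
meta.width, meta.height = 512, 512
meta.seed = 1234
print(meta.to_json())  # note: seed is omitted entirely when falsy; prompts are emitted as a pair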
class InvokeAIMetadataParser:
    """Parses strings with json data to find Invoke AI core metadata properties."""

    def __init__(self):
        pass

    def parse_meta_tag_dream(self, dream_string):
        """Take as input a png metadata json node for the 'dream' field variant from prior to 1.15"""
        props = InvokeAIMetadata()

        props.imported_app_version = "pre1.15"
        seed_match = re.search("-S\\s*(\\d+)", dream_string)
        if seed_match is not None:
            try:
                props.seed = int(seed_match[1])
            except ValueError:
                props.seed = None
            raw_prompt = re.sub("(-S\\s*\\d+)", "", dream_string)
        else:
            raw_prompt = dream_string

        pos_prompt, neg_prompt = self.split_prompt(raw_prompt)

        props.positive_prompt = pos_prompt
        props.negative_prompt = neg_prompt

        return props

    def parse_meta_tag_sd_metadata(self, tag_value):
        """Take as input a png metadata json node for the 'sd-metadata' field variant from 1.15 through 2.3.5 post 2"""
        props = InvokeAIMetadata()

        props.imported_app_version = tag_value.get("app_version")
        props.model_name = tag_value.get("model_weights")
        img_node = tag_value.get("image")
        if img_node is not None:
            props.generation_mode = img_node.get("type")
            props.width = img_node.get("width")
            props.height = img_node.get("height")
            props.seed = img_node.get("seed")
            props.rand_device = "cuda"  # hardcoded since all generations pre 3.0 used cuda random noise instead of cpu
            props.cfg_scale = img_node.get("cfg_scale")
            props.steps = img_node.get("steps")
            props.scheduler = self.map_scheduler(img_node.get("sampler"))
            props.strength = img_node.get("strength")
            if props.strength is None:
                props.strength = img_node.get("strength_steps")  # try second name for this property
            props.init_image = img_node.get("init_image_path")
            if props.init_image is None:  # try second name for this property
                props.init_image = img_node.get("init_img")
            # remove the path info from init_image so if we move the init image, it will be correctly relative in the new location
            if props.init_image is not None:
                props.init_image = os.path.basename(props.init_image)
            raw_prompt = img_node.get("prompt")
            if isinstance(raw_prompt, list):
                raw_prompt = raw_prompt[0].get("prompt")

            props.positive_prompt, props.negative_prompt = self.split_prompt(raw_prompt)

        return props

    def parse_meta_tag_invokeai(self, tag_value):
        """Take as input a png metadata json node for the 'invokeai' field variant from 3.0.0 beta 1 through 5"""
        props = InvokeAIMetadata()

        props.imported_app_version = "3.0.0 or later"
        props.generation_mode = tag_value.get("type")
        if props.generation_mode is not None:
            props.generation_mode = props.generation_mode.replace("t2l", "txt2img").replace("l2l", "img2img")

        props.width = tag_value.get("width")
        props.height = tag_value.get("height")
        props.seed = tag_value.get("seed")
        props.cfg_scale = tag_value.get("cfg_scale")
        props.steps = tag_value.get("steps")
        props.scheduler = tag_value.get("scheduler")
        props.strength = tag_value.get("strength")
        props.positive_prompt = tag_value.get("positive_conditioning")
        props.negative_prompt = tag_value.get("negative_conditioning")

        return props

    def map_scheduler(self, old_scheduler):
        """Convert the legacy sampler names to matching 3.0 schedulers"""
        if old_scheduler is None:
            return None

        match (old_scheduler):
            case "ddim":
                return "ddim"
            case "plms":
                return "pndm"
            case "k_lms":
                return "lms"
            case "k_dpm_2":
                return "kdpm_2"
            case "k_dpm_2_a":
                return "kdpm_2_a"
            case "dpmpp_2":
                return "dpmpp_2s"
            case "k_dpmpp_2":
                return "dpmpp_2m"
            case "k_dpmpp_2_a":
                return None  # invalid; in 2.3.x, selecting this sampler would just fall back to the last run, or plms in a new session
            case "k_euler":
                return "euler"
            case "k_euler_a":
                return "euler_a"
            case "k_heun":
                return "heun"
        return None

    def split_prompt(self, raw_prompt: str):
        """Split the unified prompt strings by extracting all negative prompt blocks out into the negative prompt."""
        if raw_prompt is None:
            return "", ""
        raw_prompt_search = raw_prompt.replace("\r", "").replace("\n", "")
        matches = re.findall(r"\[(.+?)\]", raw_prompt_search)
        if len(matches) > 0:
            negative_prompt = ""
            if len(matches) == 1:
                negative_prompt = matches[0].strip().strip(",")
            else:
                for match in matches:
                    negative_prompt += f"({match.strip().strip(',')})"
            positive_prompt = re.sub(r"(\[.+?\])", "", raw_prompt_search).strip()
        else:
            positive_prompt = raw_prompt_search.strip()
            negative_prompt = ""

        return positive_prompt, negative_prompt
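A quick illustration (not part of the new file) of the bracket-splitting behavior above; the prompt text is made up:

parser = InvokeAIMetadataParser()
pos, neg = parser.split_prompt("a cat sitting on a fence [blurry] [bad anatomy]")
print(pos)  # "a cat sitting on a fence"
print(neg)  # "(blurry)(bad anatomy)" -- multiple blocks are parenthesized and concatenated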
class DatabaseMapper:
    """Class to abstract database functionality."""

    def __init__(self, database_path, database_backup_dir):
        self.database_path = database_path
        self.database_backup_dir = database_backup_dir
        self.connection = None
        self.cursor = None

    def connect(self):
        """Open connection to the database."""
        self.connection = sqlite3.connect(self.database_path)
        self.cursor = self.connection.cursor()

    def get_board_names(self):
        """Get a list of the current board names from the database."""
        sql_get_board_name = "SELECT board_name FROM boards"
        self.cursor.execute(sql_get_board_name)
        rows = self.cursor.fetchall()
        return [row[0] for row in rows]

    def does_image_exist(self, image_name):
        """Check database if an image name already exists and return a boolean."""
        sql_get_image_by_name = f"SELECT image_name FROM images WHERE image_name='{image_name}'"
        self.cursor.execute(sql_get_image_by_name)
        rows = self.cursor.fetchall()
        return True if len(rows) > 0 else False

    def add_new_image_to_database(self, filename, width, height, metadata, modified_date_string):
        """Add an image to the database."""
        sql_add_image = f"""INSERT INTO images (image_name, image_origin, image_category, width, height, session_id, node_id, metadata, is_intermediate, created_at, updated_at)
VALUES ('{filename}', 'internal', 'general', {width}, {height}, null, null, '{metadata}', 0, '{modified_date_string}', '{modified_date_string}')"""
        self.cursor.execute(sql_add_image)
        self.connection.commit()

    def get_board_id_with_create(self, board_name):
        """Get the board id for supplied name, and create the board if one does not exist."""
        sql_find_board = f"SELECT board_id FROM boards WHERE board_name='{board_name}' COLLATE NOCASE"
        self.cursor.execute(sql_find_board)
        rows = self.cursor.fetchall()
        if len(rows) > 0:
            return rows[0][0]
        else:
            board_date_string = datetime.datetime.utcnow().date().isoformat()
            new_board_id = str(uuid.uuid4())
            sql_insert_board = f"INSERT INTO boards (board_id, board_name, created_at, updated_at) VALUES ('{new_board_id}', '{board_name}', '{board_date_string}', '{board_date_string}')"
            self.cursor.execute(sql_insert_board)
            self.connection.commit()
            return new_board_id

    def add_image_to_board(self, filename, board_id):
        """Add an image mapping to a board."""
        add_datetime_str = datetime.datetime.utcnow().isoformat()
        sql_add_image_to_board = f"""INSERT INTO board_images (board_id, image_name, created_at, updated_at)
VALUES ('{board_id}', '{filename}', '{add_datetime_str}', '{add_datetime_str}')"""
        self.cursor.execute(sql_add_image_to_board)
        self.connection.commit()

    def disconnect(self):
        """Disconnect from the db, cleaning up connections and cursors."""
        if self.cursor is not None:
            self.cursor.close()
        if self.connection is not None:
            self.connection.close()

    def backup(self, timestamp_string):
        """Take a backup of the database."""
        if not os.path.exists(self.database_backup_dir):
            print(f"Database backup directory {self.database_backup_dir} does not exist -> creating...", end="")
            os.makedirs(self.database_backup_dir)
            print("Done!")
        database_backup_path = os.path.join(self.database_backup_dir, f"backup-{timestamp_string}-invokeai.db")
        print(f"Making DB Backup at {database_backup_path}...", end="")
        shutil.copy2(self.database_path, database_backup_path)
        print("Done!")
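A hypothetical direct use of DatabaseMapper (not part of the new file); the file paths are illustrative and assume an existing invokeai.db:

mapper = DatabaseMapper("/path/to/invokeai.db", "/path/to/backup")  # illustrative paths
mapper.connect()
mapper.backup(Config.TIMESTAMP_STRING)          # snapshot the db before touching it
board_id = mapper.get_board_id_with_create("IMPORT")  # get-or-create a board
print(mapper.get_board_names())
mapper.disconnect()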
class MediaImportProcessor:
|
||||
"""Containing class for script functionality."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
board_name_id_map = {}
|
||||
|
||||
def get_import_file_list(self):
|
||||
"""Ask the user for the import folder and scan for the list of files to return."""
|
||||
while True:
|
||||
default = ""
|
||||
while True:
|
||||
import_dir = os.path.expanduser(
|
||||
prompt(
|
||||
"Inputs: Specify absolute path containing InvokeAI .png images to import: ",
|
||||
completer=PathCompleter(expanduser=True, only_directories=True),
|
||||
default=default,
|
||||
)
|
||||
)
|
||||
if len(import_dir) > 0 and Path(import_dir).is_dir():
|
||||
break
|
||||
default = import_dir
|
||||
|
||||
recurse_directories = (
|
||||
(prompt("Include files from subfolders recursively [yN]? ").strip() or "N").upper().startswith("N")
|
||||
)
|
||||
if recurse_directories:
|
||||
is_recurse = False
|
||||
matching_file_list = glob.glob(import_dir + "/*.png", recursive=False)
|
||||
else:
|
||||
is_recurse = True
|
||||
matching_file_list = glob.glob(import_dir + "/**/*.png", recursive=True)
|
||||
|
||||
if len(matching_file_list) > 0:
|
||||
return import_dir, is_recurse, matching_file_list
|
||||
else:
|
||||
print(f"The specific path {import_dir} exists, but does not contain .png files!")
|
||||
|
||||
def get_file_details(self, filepath):
|
||||
"""Retrieve the embedded metedata fields and dimensions from an image file."""
|
||||
with PIL.Image.open(filepath) as img:
|
||||
img.load()
|
||||
png_width, png_height = img.size
|
||||
img_info = img.info
|
||||
return img_info, png_width, png_height
|
||||
|
||||
def select_board_option(self, board_names, timestamp_string):
|
||||
"""Allow the user to choose how a board is selected for imported files."""
|
||||
while True:
|
||||
print("\r\nOptions for board selection for imported images:")
|
||||
print(f"1) Select an existing board name. (found {len(board_names)})")
|
||||
print("2) Specify a board name to create/add to.")
|
||||
print("3) Create/add to board named 'IMPORT'.")
|
||||
print(
|
||||
f"4) Create/add to board named 'IMPORT' with the current datetime string appended (.e.g IMPORT_{timestamp_string})."
|
||||
)
|
||||
print(
|
||||
"5) Create/add to board named 'IMPORT' with a the original file app_version appended (.e.g IMPORT_2.2.5)."
|
||||
)
|
||||
input_option = input("Specify desired board option: ")
|
||||
match (input_option):
|
||||
case "1":
|
||||
if len(board_names) < 1:
|
||||
print("\r\nThere are no existing board names to choose from. Select another option!")
|
||||
continue
|
||||
board_name = self.select_item_from_list(
|
||||
board_names, "board name", True, "Cancel, go back and choose a different board option."
|
||||
)
|
||||
if board_name is not None:
|
||||
return board_name
|
||||
case "2":
|
||||
while True:
|
||||
board_name = input("Specify new/existing board name: ")
|
||||
if board_name:
|
||||
return board_name
|
||||
case "3":
|
||||
return "IMPORT"
|
||||
case "4":
|
||||
return f"IMPORT_{timestamp_string}"
|
||||
case "5":
|
||||
return "IMPORT_APPVERSION"
|
||||
|
||||
def select_item_from_list(self, items, entity_name, allow_cancel, cancel_string):
|
||||
"""A general function to render a list of items to select in the console, prompt the user for a selection and ensure a valid entry is selected."""
|
||||
print(f"Select a {entity_name.lower()} from the following list:")
|
||||
index = 1
|
||||
for item in items:
|
||||
print(f"{index}) {item}")
|
||||
index += 1
|
||||
if allow_cancel:
|
||||
print(f"{index}) {cancel_string}")
|
||||
while True:
|
||||
try:
|
||||
option_number = int(input("Specify number of selection: "))
|
||||
except ValueError:
|
||||
continue
|
||||
if allow_cancel and option_number == index:
|
||||
return None
|
||||
if option_number >= 1 and option_number <= len(items):
|
||||
return items[option_number - 1]
|
||||
|
||||
def import_image(self, filepath: str, board_name_option: str, db_mapper: DatabaseMapper, config: Config):
|
||||
"""Import a single file by its path"""
|
||||
parser = InvokeAIMetadataParser()
|
||||
file_name = os.path.basename(filepath)
|
||||
file_destination_path = os.path.join(config.outputs_path, file_name)
|
||||
|
||||
print("===============================================================================")
|
||||
print(f"Importing {filepath}")
|
||||
|
||||
# check destination to see if the file was previously imported
|
||||
if os.path.exists(file_destination_path):
|
||||
print("File already exists in the destination, skipping!")
|
||||
ImportStats.count_skipped_file_exists += 1
|
||||
return
|
||||
|
||||
# check if file name is already referenced in the database
|
||||
if db_mapper.does_image_exist(file_name):
|
||||
print("A reference to a file with this name already exists in the database, skipping!")
|
||||
ImportStats.count_skipped_db_exists += 1
|
||||
return
|
||||
|
||||
# load image info and dimensions
|
||||
img_info, png_width, png_height = self.get_file_details(filepath)
|
||||
|
||||
# parse metadata
|
||||
destination_needs_meta_update = True
|
||||
log_version_note = "(Unknown)"
|
||||
if "invokeai_metadata" in img_info:
|
||||
# for the latest, we will just re-emit the same json, no need to parse/modify
|
||||
converted_field = None
|
||||
latest_json_string = img_info.get("invokeai_metadata")
|
||||
log_version_note = "3.0.0+"
|
||||
destination_needs_meta_update = False
|
||||
else:
|
||||
if "sd-metadata" in img_info:
|
||||
converted_field = parser.parse_meta_tag_sd_metadata(json.loads(img_info.get("sd-metadata")))
|
||||
elif "invokeai" in img_info:
|
||||
converted_field = parser.parse_meta_tag_invokeai(json.loads(img_info.get("invokeai")))
|
||||
elif "dream" in img_info:
|
||||
converted_field = parser.parse_meta_tag_dream(img_info.get("dream"))
|
||||
elif "Dream" in img_info:
|
||||
converted_field = parser.parse_meta_tag_dream(img_info.get("Dream"))
|
||||
else:
|
||||
converted_field = InvokeAIMetadata()
|
||||
destination_needs_meta_update = False
|
||||
print("File does not have metadata from known Invoke AI versions, add only, no update!")
|
||||
|
||||
# use the loaded img dimensions if the metadata didnt have them
|
||||
if converted_field.width is None:
|
||||
converted_field.width = png_width
|
||||
if converted_field.height is None:
|
||||
converted_field.height = png_height
|
||||
|
||||
log_version_note = converted_field.imported_app_version if converted_field else "NoVersion"
|
||||
log_version_note = log_version_note or "NoVersion"
|
||||
|
||||
latest_json_string = converted_field.to_json()
|
||||
|
||||
print(f"From Invoke AI Version {log_version_note} with dimensions {png_width} x {png_height}.")
|
||||
|
||||
# if metadata needs update, then update metdata and copy in one shot
|
||||
if destination_needs_meta_update:
|
||||
print("Updating metadata while copying...", end="")
|
||||
self.update_file_metadata_while_copying(
|
||||
filepath, file_destination_path, "invokeai_metadata", latest_json_string
|
||||
)
|
||||
print("Done!")
|
||||
else:
|
||||
print("No metadata update necessary, copying only...", end="")
|
||||
shutil.copy2(filepath, file_destination_path)
|
||||
print("Done!")
|
||||
|
||||
# create thumbnail
|
||||
print("Creating thumbnail...", end="")
|
||||
thumbnail_path = os.path.join(config.thumbnail_path, os.path.splitext(file_name)[0]) + ".webp"
|
||||
thumbnail_size = 256, 256
|
||||
with PIL.Image.open(filepath) as source_image:
|
||||
source_image.thumbnail(thumbnail_size)
|
||||
source_image.save(thumbnail_path, "webp")
|
||||
print("Done!")
|
||||
|
||||
# finalize the dynamic board name if there is an APPVERSION token in it.
|
||||
if converted_field is not None:
|
||||
board_name = board_name_option.replace("APPVERSION", converted_field.imported_app_version or "NoVersion")
|
||||
else:
|
||||
board_name = board_name_option.replace("APPVERSION", "Latest")
|
||||
|
||||
# maintain a map of alrady created/looked up ids to avoid DB queries
|
||||
print("Finding/Creating board...", end="")
|
||||
if board_name in self.board_name_id_map:
|
||||
board_id = self.board_name_id_map[board_name]
|
||||
else:
|
||||
board_id = db_mapper.get_board_id_with_create(board_name)
|
||||
self.board_name_id_map[board_name] = board_id
|
||||
print("Done!")
|
||||
|
||||
# add image to db
|
||||
print("Adding image to database......", end="")
|
||||
modified_time = datetime.datetime.utcfromtimestamp(os.path.getmtime(filepath))
|
||||
db_mapper.add_new_image_to_database(file_name, png_width, png_height, latest_json_string, modified_time)
|
||||
print("Done!")
|
||||
|
||||
# add image to board
|
||||
print("Adding image to board......", end="")
|
||||
db_mapper.add_image_to_board(file_name, board_id)
|
||||
print("Done!")
|
||||
|
||||
ImportStats.count_imported += 1
|
||||
if log_version_note in ImportStats.count_imported_by_version:
|
||||
ImportStats.count_imported_by_version[log_version_note] += 1
|
||||
else:
|
||||
ImportStats.count_imported_by_version[log_version_note] = 1
|
||||
|
||||
def update_file_metadata_while_copying(self, filepath, file_destination_path, tag_name, tag_value):
|
||||
"""Perform a metadata update with save to a new destination which accomplishes a copy while updating metadata."""
|
||||
with PIL.Image.open(filepath) as target_image:
|
||||
existing_img_info = target_image.info
|
||||
metadata = PIL.PngImagePlugin.PngInfo()
|
||||
# re-add any existing invoke ai tags unless they are the one we are trying to add
|
||||
for key in existing_img_info:
|
||||
if key != tag_name and key in ("dream", "Dream", "sd-metadata", "invokeai", "invokeai_metadata"):
|
||||
metadata.add_text(key, existing_img_info[key])
|
||||
metadata.add_text(tag_name, tag_value)
|
||||
target_image.save(file_destination_path, pnginfo=metadata)

    def process(self):
        """Begin main processing."""

        print("===============================================================================")
        print("This script will import images generated by earlier versions of")
        print("InvokeAI into the currently installed root directory:")
        print(f"  {app_config.root_path}")
        print("If this is not what you want to do, type ctrl-C now to cancel.")

        # load config
        print("===============================================================================")
        print("= Configuration & Settings")

        config = Config()
        config.find_and_load()
        db_mapper = DatabaseMapper(config.database_path, config.database_backup_dir)
        db_mapper.connect()

        import_dir, is_recurse, import_file_list = self.get_import_file_list()
        ImportStats.count_source_files = len(import_file_list)

        board_names = db_mapper.get_board_names()
        board_name_option = self.select_board_option(board_names, config.TIMESTAMP_STRING)

        print("\r\n===============================================================================")
        print("= Import Settings Confirmation")

        print()
        print(f"Database File Path               : {config.database_path}")
        print(f"Outputs/Images Directory         : {config.outputs_path}")
        print(f"Import Image Source Directory    : {import_dir}")
        print(f"  Recurse Source SubDirectories  : {'Yes' if is_recurse else 'No'}")
        print(f"Count of .png file(s) found      : {len(import_file_list)}")
        print(f"Board name option specified      : {board_name_option}")
        print(f"Database backup will be taken at : {config.database_backup_dir}")

        print("\r\nNotes about the import process:")
        print("- Source image files will not be modified, only copied to the outputs directory.")
        print("- If the same file name already exists in the destination, the file will be skipped.")
        print("- If the same file name already has a record in the database, the file will be skipped.")
        print("- Invoke AI metadata tags will be updated/written into the imported copy only.")
        print(
            "- On the imported copy, only Invoke AI known tags (latest and legacy) will be retained (dream, sd-metadata, invokeai, invokeai_metadata)"
        )
        print(
            "- A property 'imported_app_version' will be added to metadata that can be viewed in the UI's metadata viewer."
        )
        print(
            "- The new 3.x InvokeAI outputs folder structure is flat so recursively found source images will all be placed into the single outputs/images folder."
        )

        while True:
            should_continue = prompt("\nDo you wish to continue with the import [Yn] ? ").lower() or "y"
            if should_continue == "n":
                print("\r\nCancelling Import")
                return
            elif should_continue == "y":
                print()
                break

        db_mapper.backup(config.TIMESTAMP_STRING)

        print()
        ImportStats.time_start = datetime.datetime.utcnow()

        for filepath in import_file_list:
            try:
                self.import_image(filepath, board_name_option, db_mapper, config)
            except sqlite3.Error as sql_ex:
                print(f"A database related exception was found processing {filepath}, will continue to next file. ")
                print("Exception detail:")
                print(sql_ex)
                ImportStats.count_file_errors += 1
            except Exception as ex:
                print(f"Exception processing {filepath}, will continue to next file. ")
                print("Exception detail:")
                print(ex)
                ImportStats.count_file_errors += 1

        print("\r\n===============================================================================")
        print(f"= Import Complete - Elapsed Time: {ImportStats.get_elapsed_time_string()}")
        print()
        print(f"Source File(s)                          : {ImportStats.count_source_files}")
        print(f"Total Imported                          : {ImportStats.count_imported}")
        print(f"Skipped b/c file already exists on disk : {ImportStats.count_skipped_file_exists}")
        print(f"Skipped b/c file already exists in db   : {ImportStats.count_skipped_db_exists}")
        print(f"Errors during import                    : {ImportStats.count_file_errors}")
        if ImportStats.count_imported > 0:
            print("\r\nBreakdown of imported files by version:")
            for key, count in ImportStats.count_imported_by_version.items():
                print(f"  {key:20} : {count}")


def main():
    try:
        processor = MediaImportProcessor()
        processor.process()
    except KeyboardInterrupt:
        print("\r\n\r\nUser cancelled execution.")


if __name__ == "__main__":
    main()
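A quick way to sanity-check the copy-and-retag behavior described in the notes above is to read the text chunks back out of an imported PNG. This is only an illustrative sketch, not part of the script, and the file path is hypothetical:

import PIL.Image

# Hypothetical path to one of the imported copies
with PIL.Image.open("outputs/images/example_imported.png") as img:
    # Pillow exposes PNG tEXt/iTXt chunks on the .text mapping, so the tag
    # written by update_file_metadata_while_copying should round-trip here.
    print(img.text.get("invokeai_metadata"))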
@@ -1,4 +1,4 @@
"""
Wrapper for invokeai.backend.configure.invokeai_configure
"""
from ...backend.install.invokeai_configure import main
from ...backend.install.invokeai_configure import main as invokeai_configure
@@ -28,7 +28,6 @@ from npyscreen import widget
from invokeai.backend.util.logging import InvokeAILogger

from invokeai.backend.install.model_install_backend import (
    ModelInstallList,
    InstallSelections,
    ModelInstall,
    SchedulerPredictionType,
@@ -41,12 +40,12 @@ from invokeai.frontend.install.widgets import (
    SingleSelectColumns,
    TextBox,
    BufferBox,
    FileBox,
    set_min_terminal_size,
    select_stable_diffusion_config_file,
    CyclingForm,
    MIN_COLS,
    MIN_LINES,
    WindowTooSmallException,
)
from invokeai.app.services.config import InvokeAIAppConfig

@@ -156,7 +155,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
            BufferBox,
            name="Log Messages",
            editable=False,
            max_height=15,
            max_height=6,
        )

        self.nextrely += 1
@@ -693,7 +692,11 @@ def select_and_download_models(opt: Namespace):
    # needed to support the probe() method running under a subprocess
    torch.multiprocessing.set_start_method("spawn")

    set_min_terminal_size(MIN_COLS, MIN_LINES)
    if not set_min_terminal_size(MIN_COLS, MIN_LINES):
        raise WindowTooSmallException(
            "Could not increase terminal size. Try running again with a larger window or smaller font size."
        )

    installApp = AddModelApplication(opt)
    try:
        installApp.run()
@@ -787,6 +790,8 @@ def main():
        curses.echo()
        curses.endwin()
        logger.info("Goodbye! Come back soon.")
    except WindowTooSmallException as e:
        logger.error(str(e))
    except widget.NotEnoughSpaceForWidget as e:
        if str(e).startswith("Height of 1 allocated"):
            logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")
@@ -21,31 +21,40 @@ MIN_COLS = 130
MIN_LINES = 38


class WindowTooSmallException(Exception):
    pass


# -------------------------------------
def set_terminal_size(columns: int, lines: int):
    ts = get_terminal_size()
    width = max(columns, ts.columns)
    height = max(lines, ts.lines)

def set_terminal_size(columns: int, lines: int) -> bool:
    OS = platform.uname().system
    if OS == "Windows":
        pass
        # not working reliably - ask user to adjust the window
        # _set_terminal_size_powershell(width,height)
    elif OS in ["Darwin", "Linux"]:
        _set_terminal_size_unix(width, height)
    screen_ok = False
    while not screen_ok:
        ts = get_terminal_size()
        width = max(columns, ts.columns)
        height = max(lines, ts.lines)

        # check whether it worked....
        ts = get_terminal_size()
        pause = False
        if ts.columns < columns:
            print("\033[1mThis window is too narrow for the user interface.\033[0m")
            pause = True
        if ts.lines < lines:
            print("\033[1mThis window is too short for the user interface.\033[0m")
            pause = True
        if pause:
            input("Maximize the window then press any key to continue..")
        if OS == "Windows":
            pass
            # not working reliably - ask user to adjust the window
            # _set_terminal_size_powershell(width,height)
        elif OS in ["Darwin", "Linux"]:
            _set_terminal_size_unix(width, height)

        # check whether it worked....
        ts = get_terminal_size()
        if ts.columns < columns or ts.lines < lines:
            print(
                f"\033[1mThis window is too small for the interface. InvokeAI requires {columns}x{lines} (w x h) characters, but window is {ts.columns}x{ts.lines}\033[0m"
            )
            resp = input(
                "Maximize the window and/or decrease the font size then press any key to continue. Type [Q] to give up.."
            )
            if resp.upper().startswith("Q"):
                break
        else:
            screen_ok = True
    return screen_ok


def _set_terminal_size_powershell(width: int, height: int):
@@ -80,14 +89,14 @@ def _set_terminal_size_unix(width: int, height: int):
    sys.stdout.flush()


def set_min_terminal_size(min_cols: int, min_lines: int):
def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
    # make sure there's enough room for the ui
    term_cols, term_lines = get_terminal_size()
    if term_cols >= min_cols and term_lines >= min_lines:
        return
        return True
    cols = max(term_cols, min_cols)
    lines = max(term_lines, min_lines)
    set_terminal_size(cols, lines)
    return set_terminal_size(cols, lines)
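On Darwin and Linux the actual resize is delegated to _set_terminal_size_unix, whose body is mostly elided from this hunk. A helper like it is commonly written with the xterm-compatible resize escape; the sketch below is an assumption about that approach, not the function's confirmed body:

import sys

def resize_terminal_sketch(width: int, height: int) -> None:
    # CSI 8 ; rows ; cols t asks an xterm-compatible emulator to resize.
    # Emulators that ignore the request leave the window unchanged, which
    # is exactly why set_terminal_size() re-checks get_terminal_size()
    # and falls back to prompting the user.
    sys.stdout.write(f"\x1b[8;{height};{width}t")
    sys.stdout.flush()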


class IntSlider(npyscreen.Slider):
@@ -164,7 +173,7 @@ class FloatSlider(npyscreen.Slider):


class FloatTitleSlider(npyscreen.TitleText):
    _entry_type = FloatSlider
    _entry_type = npyscreen.Slider


class SelectColumnBase:

@@ -382,7 +382,8 @@ def run_cli(args: Namespace):

def main():
    args = _parse_args()
    config.parse_args(["--root", str(args.root_dir)])
    if args.root_dir:
        config.parse_args(["--root", str(args.root_dir)])

    try:
        if args.front_end:
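The added guard matters because argparse typically yields None for an omitted option, and the old unconditional call stringified it. A two-line illustration (the variable is a stand-in for args.root_dir, not part of the diff):

root_dir = None  # what argparse typically yields when --root is omitted
print(["--root", str(root_dir)])  # -> ['--root', 'None'], a bogus literal path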
169  invokeai/frontend/web/dist/assets/App-44cdaaf3.js (vendored; file diff suppressed because one or more lines are too long)
169  invokeai/frontend/web/dist/assets/App-fd54b7b9.js (vendored, new file; file diff suppressed because one or more lines are too long)
@@ -1,4 +1,4 @@
import{A as m,f$ as Je,z as y,a4 as Ka,g0 as Xa,af as va,aj as d,g1 as b,g2 as t,g3 as Ya,g4 as h,g5 as ua,g6 as Ja,g7 as Qa,aI as Za,g8 as et,ad as rt,g9 as at}from"./index-18f2f740.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./MantineProvider-b20a2267.js";var za=String.raw,Ca=za`
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-815faab3.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-e9f8a36e.js";var za=String.raw,Ca=za`
:root,
:host {
  --chakra-vh: 100vh;
125  invokeai/frontend/web/dist/assets/index-18f2f740.js (vendored; file diff suppressed because one or more lines are too long)
151  invokeai/frontend/web/dist/assets/index-815faab3.js (vendored, new file; file diff suppressed because one or more lines are too long)
1    invokeai/frontend/web/dist/assets/menu-e9f8a36e.js (vendored, new file; file diff suppressed because one or more lines are too long)
2    invokeai/frontend/web/dist/index.html (vendored)
@@ -12,7 +12,7 @@
      margin: 0;
    }
  </style>
  <script type="module" crossorigin src="./assets/index-18f2f740.js"></script>
  <script type="module" crossorigin src="./assets/index-815faab3.js"></script>
</head>

<body dir="ltr">
3  invokeai/frontend/web/dist/locales/en.json (vendored)
@@ -124,7 +124,8 @@
    "deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
    "deleteImagePermanent": "Deleted images cannot be restored.",
    "images": "Images",
    "assets": "Assets"
    "assets": "Assets",
    "autoAssignBoardOnClick": "Auto-Assign Board on Click"
  },
  "hotkeys": {
    "keyboardShortcuts": "Keyboard Shortcuts",
@@ -23,7 +23,7 @@
    "dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
    "dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
    "build": "yarn run lint && vite build",
    "typegen": "npx ts-node scripts/typegen.ts",
    "typegen": "node scripts/typegen.js",
    "preview": "vite preview",
    "lint:madge": "madge --circular src/main.tsx",
    "lint:eslint": "eslint --max-warnings=0 .",
@@ -124,7 +124,8 @@
    "deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
    "deleteImagePermanent": "Deleted images cannot be restored.",
    "images": "Images",
    "assets": "Assets"
    "assets": "Assets",
    "autoAssignBoardOnClick": "Auto-Assign Board on Click"
  },
  "hotkeys": {
    "keyboardShortcuts": "Keyboard Shortcuts",
@@ -4,8 +4,9 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
import SiteHeader from 'features/system/components/SiteHeader';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
@@ -16,7 +17,6 @@ import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import i18n from 'i18n';
import { size } from 'lodash-es';
import { ReactNode, memo, useEffect } from 'react';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';

@@ -84,7 +84,7 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
      </Portal>
    </Grid>
    <DeleteImageModal />
    <UpdateImageBoardModal />
    <ChangeBoardModal />
    <Toaster />
    <GlobalHotkeys />
  </>
@@ -58,7 +58,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
    );
  }

  if (props.dragData.payloadType === 'IMAGE_NAMES') {
  if (props.dragData.payloadType === 'IMAGE_DTOS') {
    return (
      <Flex
        sx={{
@@ -71,7 +71,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
          ...STYLES,
        }}
      >
        <Heading>{props.dragData.payload.image_names.length}</Heading>
        <Heading>{props.dragData.payload.imageDTOs.length}</Heading>
        <Heading size="sm">Images</Heading>
      </Flex>
    );
@@ -18,27 +18,32 @@ import {
  DragStartEvent,
  TypesafeDraggableData,
} from './typesafeDnd';
import { logger } from 'app/logging/logger';

type ImageDndContextProps = PropsWithChildren;

const ImageDndContext = (props: ImageDndContextProps) => {
  const [activeDragData, setActiveDragData] =
    useState<TypesafeDraggableData | null>(null);
  const log = logger('images');

  const dispatch = useAppDispatch();

  const handleDragStart = useCallback((event: DragStartEvent) => {
    console.log('dragStart', event.active.data.current);
    const activeData = event.active.data.current;
    if (!activeData) {
      return;
    }
    setActiveDragData(activeData);
  }, []);
  const handleDragStart = useCallback(
    (event: DragStartEvent) => {
      log.trace({ dragData: event.active.data.current }, 'Drag started');
      const activeData = event.active.data.current;
      if (!activeData) {
        return;
      }
      setActiveDragData(activeData);
    },
    [log]
  );

  const handleDragEnd = useCallback(
    (event: DragEndEvent) => {
      console.log('dragEnd', event.active.data.current);
      log.trace({ dragData: event.active.data.current }, 'Drag ended');
      const overData = event.over?.data.current;
      if (!activeDragData || !overData) {
        return;
@@ -46,7 +51,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
      dispatch(dndDropped({ overData, activeData: activeDragData }));
      setActiveDragData(null);
    },
    [activeDragData, dispatch]
    [activeDragData, dispatch, log]
  );

  const mouseSensor = useSensor(MouseSensor, {
@@ -11,7 +11,6 @@ import {
  useDraggable as useOriginalDraggable,
  useDroppable as useOriginalDroppable,
} from '@dnd-kit/core';
import { BoardId } from 'features/gallery/store/types';
import { ImageDTO } from 'services/api/types';

type BaseDropData = {
@@ -54,9 +53,13 @@ export type AddToBatchDropData = BaseDropData & {
  actionType: 'ADD_TO_BATCH';
};

export type MoveBoardDropData = BaseDropData & {
  actionType: 'MOVE_BOARD';
  context: { boardId: BoardId };
export type AddToBoardDropData = BaseDropData & {
  actionType: 'ADD_TO_BOARD';
  context: { boardId: string };
};

export type RemoveFromBoardDropData = BaseDropData & {
  actionType: 'REMOVE_FROM_BOARD';
};

export type TypesafeDroppableData =
@@ -67,7 +70,8 @@ export type TypesafeDroppableData =
  | NodesImageDropData
  | AddToBatchDropData
  | NodesMultiImageDropData
  | MoveBoardDropData;
  | AddToBoardDropData
  | RemoveFromBoardDropData;

type BaseDragData = {
  id: string;
@@ -78,14 +82,12 @@ export type ImageDraggableData = BaseDragData & {
  payload: { imageDTO: ImageDTO };
};

export type ImageNamesDraggableData = BaseDragData & {
  payloadType: 'IMAGE_NAMES';
  payload: { image_names: string[] };
export type ImageDTOsDraggableData = BaseDragData & {
  payloadType: 'IMAGE_DTOS';
  payload: { imageDTOs: ImageDTO[] };
};

export type TypesafeDraggableData =
  | ImageDraggableData
  | ImageNamesDraggableData;
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData;

interface UseDroppableTypesafeArguments
  extends Omit<UseDroppableArguments, 'data'> {
@@ -156,14 +158,39 @@ export const isValidDrop = (
    case 'SET_NODES_IMAGE':
      return payloadType === 'IMAGE_DTO';
    case 'SET_MULTI_NODES_IMAGE':
      return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
      return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
    case 'ADD_TO_BATCH':
      return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
    case 'MOVE_BOARD': {
      return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
    case 'ADD_TO_BOARD': {
      // If the board is the same, don't allow the drop

      // Check the payload types
      const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
      const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
      if (!isPayloadValid) {
        return false;
      }

      // Check if the image's board is the board we are dragging onto
      if (payloadType === 'IMAGE_DTO') {
        const { imageDTO } = active.data.current.payload;
        const currentBoard = imageDTO.board_id ?? 'none';
        const destinationBoard = overData.context.boardId;

        return currentBoard !== destinationBoard;
      }

      if (payloadType === 'IMAGE_DTOS') {
        // TODO (multi-select)
        return true;
      }

      return false;
    }
    case 'REMOVE_FROM_BOARD': {
      // If the board is the same, don't allow the drop

      // Check the payload types
      const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
      if (!isPayloadValid) {
        return false;
      }
@@ -172,20 +199,16 @@ export const isValidDrop = (
      if (payloadType === 'IMAGE_DTO') {
        const { imageDTO } = active.data.current.payload;
        const currentBoard = imageDTO.board_id;
        const destinationBoard = overData.context.boardId;

        const isSameBoard = currentBoard === destinationBoard;
        const isDestinationValid = !currentBoard ? destinationBoard : true;

        return !isSameBoard && isDestinationValid;
        return currentBoard !== 'none';
      }

      if (payloadType === 'IMAGE_NAMES') {
      if (payloadType === 'IMAGE_DTOS') {
        // TODO (multi-select)
        return false;
        return true;
      }

      return true;
      return false;
    }
    default:
      return false;
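One thing worth flagging in the hunk above: expressions of the form payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS' do not compare payloadType against both strings; the right operand is a bare, truthy string literal, so the whole expression is always truthy. JavaScript's || short-circuits the same way Python's or does, so the pitfall can be sketched in Python:

payload_type = "SOMETHING_UNRELATED"

# Looks like a two-way comparison, but the right operand is just a
# truthy string, so the expression is always truthy:
broken = payload_type == "IMAGE_DTO" or "IMAGE_DTOS"
print(bool(broken))  # True, regardless of payload_type

# The intended check names each candidate explicitly:
fixed = payload_type in ("IMAGE_DTO", "IMAGE_DTOS")
print(fixed)  # False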
@@ -1,4 +1,6 @@
import { Middleware } from '@reduxjs/toolkit';
import { store } from 'app/store/store';
import { PartialAppConfig } from 'app/types/invokeai';
import React, {
  lazy,
  memo,
@@ -7,16 +9,11 @@ import React, {
  useEffect,
} from 'react';
import { Provider } from 'react-redux';

import { PartialAppConfig } from 'app/types/invokeai';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import Loading from '../../common/components/Loading/Loading';

import { Middleware } from '@reduxjs/toolkit';
import { $authToken, $baseUrl } from 'services/api/client';
import { $authToken, $baseUrl, $projectId } from 'services/api/client';
import { socketMiddleware } from 'services/events/middleware';
import Loading from '../../common/components/Loading/Loading';
import '../../i18n';
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
import ImageDndContext from './ImageDnd/ImageDndContext';

const App = lazy(() => import('./App'));
@@ -37,6 +34,7 @@ const InvokeAIUI = ({
  config,
  headerComponent,
  middleware,
  projectId,
}: Props) => {
  useEffect(() => {
    // configure API client token
@@ -49,6 +47,11 @@
      $baseUrl.set(apiUrl);
    }

    // configure API client project header
    if (projectId) {
      $projectId.set(projectId);
    }

    // reset dynamically added middlewares
    resetMiddlewares();

@@ -68,8 +71,9 @@ const InvokeAIUI = ({
      // Reset the API client token and base url on unmount
      $baseUrl.set(undefined);
      $authToken.set(undefined);
      $projectId.set(undefined);
    };
  }, [apiUrl, token, middleware]);
  }, [apiUrl, token, middleware, projectId]);

  return (
    <React.StrictMode>
@@ -77,9 +81,7 @@ const InvokeAIUI = ({
      <React.Suspense fallback={<Loading />}>
        <ThemeLocaleProvider>
          <ImageDndContext>
            <AddImageToBoardContextProvider>
              <App config={config} headerComponent={headerComponent} />
            </AddImageToBoardContextProvider>
            <App config={config} headerComponent={headerComponent} />
          </ImageDndContext>
        </ThemeLocaleProvider>
      </React.Suspense>
@@ -1,91 +0,0 @@
import { useDisclosure } from '@chakra-ui/react';
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
import { ImageDTO } from 'services/api/types';
import { imagesApi } from 'services/api/endpoints/images';
import { useAppDispatch } from '../store/storeHooks';

export type ImageUsage = {
  isInitialImage: boolean;
  isCanvasImage: boolean;
  isNodesImage: boolean;
  isControlNetImage: boolean;
};

type AddImageToBoardContextValue = {
  /**
   * Whether the move image dialog is open.
   */
  isOpen: boolean;
  /**
   * Closes the move image dialog.
   */
  onClose: () => void;
  /**
   * The image pending movement
   */
  image?: ImageDTO;
  onClickAddToBoard: (image: ImageDTO) => void;
  handleAddToBoard: (boardId: string) => void;
};

export const AddImageToBoardContext =
  createContext<AddImageToBoardContextValue>({
    isOpen: false,
    onClose: () => undefined,
    onClickAddToBoard: () => undefined,
    handleAddToBoard: () => undefined,
  });

type Props = PropsWithChildren;

export const AddImageToBoardContextProvider = (props: Props) => {
  const [imageToMove, setImageToMove] = useState<ImageDTO>();
  const { isOpen, onOpen, onClose } = useDisclosure();
  const dispatch = useAppDispatch();

  // Clean up after deleting or dismissing the modal
  const closeAndClearImageToDelete = useCallback(() => {
    setImageToMove(undefined);
    onClose();
  }, [onClose]);

  const onClickAddToBoard = useCallback(
    (image?: ImageDTO) => {
      if (!image) {
        return;
      }
      setImageToMove(image);
      onOpen();
    },
    [setImageToMove, onOpen]
  );

  const handleAddToBoard = useCallback(
    (boardId: string) => {
      if (imageToMove) {
        dispatch(
          imagesApi.endpoints.addImageToBoard.initiate({
            imageDTO: imageToMove,
            board_id: boardId,
          })
        );
        closeAndClearImageToDelete();
      }
    },
    [dispatch, closeAndClearImageToDelete, imageToMove]
  );

  return (
    <AddImageToBoardContext.Provider
      value={{
        isOpen,
        image: imageToMove,
        onClose: closeAndClearImageToDelete,
        onClickAddToBoard,
        handleAddToBoard,
      }}
    >
      {props.children}
    </AddImageToBoardContext.Provider>
  );
};
@@ -1,8 +0,0 @@
import { createContext } from 'react';

type VoidFunc = () => void;

type ImageUploaderTriggerContextType = VoidFunc | null;

export const ImageUploaderTriggerContext =
  createContext<ImageUploaderTriggerContextType>(null);
@@ -23,6 +23,6 @@ const serializationDenylist: {
};

export const serialize: SerializeFunction = (data, key) => {
  const result = omit(data, serializationDenylist[key]);
  const result = omit(data, serializationDenylist[key] ?? []);
  return JSON.stringify(result);
};
@@ -27,7 +27,8 @@ import {
  addImageDeletedFulfilledListener,
  addImageDeletedPendingListener,
  addImageDeletedRejectedListener,
  addRequestedImageDeletionListener,
  addRequestedSingleImageDeletionListener,
  addRequestedMultipleImageDeletionListener,
} from './listeners/imageDeleted';
import { addImageDroppedListener } from './listeners/imageDropped';
import {
@@ -111,7 +112,8 @@ addImageUploadedRejectedListener();
addInitialImageSelectedListener();

// Image deleted
addRequestedImageDeletionListener();
addRequestedSingleImageDeletionListener();
addRequestedMultipleImageDeletionListener();
addImageDeletedPendingListener();
addImageDeletedFulfilledListener();
addImageDeletedRejectedListener();
@@ -1,12 +1,10 @@
import { createAction } from '@reduxjs/toolkit';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import {
  ImageCache,
  getListImagesUrl,
  imagesApi,
} from 'services/api/endpoints/images';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { getListImagesUrl, imagesAdapter } from 'services/api/util';
import { ImageCache } from 'services/api/types';

export const appStarted = createAction('app/appStarted');

@@ -34,7 +32,8 @@ export const addFirstListImagesListener = () => {

      if (data.ids.length > 0) {
        // Select the first image
        dispatch(imageSelected(data.ids[0] as string));
        const firstImage = imagesAdapter.getSelectors().selectAll(data)[0];
        dispatch(imageSelected(firstImage ?? null));
      }
    },
  });
@@ -18,7 +18,9 @@ export const addAppConfigReceivedListener = () => {
      const infillMethod = getState().generation.infillMethod;

      if (!infill_methods.includes(infillMethod)) {
        dispatch(setInfillMethod(infill_methods[0]));
        // if there is no infill method, set it to the first one
        // if there is no first one... god help us
        dispatch(setInfillMethod(infill_methods[0] as string));
      }

      if (!nsfw_methods.includes('nsfw_checker')) {
@@ -1,14 +1,14 @@
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { getImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { boardsApi } from '../../../../../services/api/endpoints/boards';

export const addDeleteBoardAndImagesFulfilledListener = () => {
  startAppListening({
    matcher: boardsApi.endpoints.deleteBoardAndImages.matchFulfilled,
    matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
    effect: async (action, { dispatch, getState }) => {
      const { deleted_images } = action.payload;

@@ -10,6 +10,7 @@ import {
} from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { imagesSelectors } from 'services/api/util';

export const addBoardIdSelectedListener = () => {
  startAppListening({
@@ -52,8 +53,9 @@ export const addBoardIdSelectedListener = () => {
        queryArgs
      )(getState());

      if (boardImagesData?.ids.length) {
        dispatch(imageSelected((boardImagesData.ids[0] as string) ?? null));
      if (boardImagesData) {
        const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
        dispatch(imageSelected(firstImage ?? null));
      } else {
        // board has no images - deselect
        dispatch(imageSelected(null));
@@ -26,6 +26,8 @@ export const addCanvasSavedToGalleryListener = () => {
        return;
      }

      const { autoAddBoardId } = state.gallery;

      dispatch(
        imagesApi.endpoints.uploadImage.initiate({
          file: new File([blob], 'savedCanvas.png', {
@@ -33,7 +35,7 @@
          }),
          image_category: 'general',
          is_intermediate: false,
          board_id: state.gallery.autoAddBoardId,
          board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
          crop_visible: true,
          postUploadAction: {
            type: 'TOAST',
@@ -31,15 +31,20 @@ const predicate: AnyListenerPredicate<RootState> = (
    // do not process if the user just disabled auto-config
    if (
      prevState.controlNet.controlNets[action.payload.controlNetId]
        .shouldAutoConfig === true
        ?.shouldAutoConfig === true
    ) {
      return false;
    }
  }

  const { controlImage, processorType, shouldAutoConfig } =
    state.controlNet.controlNets[action.payload.controlNetId];
  const cn = state.controlNet.controlNets[action.payload.controlNetId];

  if (!cn) {
    // something is wrong, the controlNet should exist
    return false;
  }

  const { controlImage, processorType, shouldAutoConfig } = cn;
  if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
    // do not process if the action is a model change but the processor settings are dirty
    return false;
@@ -17,7 +17,7 @@ export const addControlNetImageProcessedListener = () => {
      const { controlNetId } = action.payload;
      const controlNet = getState().controlNet.controlNets[controlNetId];

      if (!controlNet.controlImage) {
      if (!controlNet?.controlImage) {
        log.error('Unable to process ControlNet image');
        return;
      }
@@ -1,57 +1,72 @@
import { logger } from 'app/logging/logger';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
import { isModalOpenChanged } from 'features/imageDeletion/store/imageDeletionSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { startAppListening } from '..';

/**
 * Called when the user requests an image deletion
 */
export const addRequestedImageDeletionListener = () => {
export const addRequestedSingleImageDeletionListener = () => {
  startAppListening({
    actionCreator: imageDeletionConfirmed,
    effect: async (action, { dispatch, getState, condition }) => {
      const { imageDTO, imageUsage } = action.payload;
      const { imageDTOs, imagesUsage } = action.payload;

      if (imageDTOs.length !== 1 || imagesUsage.length !== 1) {
        // handle multiples in separate listener
        return;
      }

      const imageDTO = imageDTOs[0];
      const imageUsage = imagesUsage[0];

      if (!imageDTO || !imageUsage) {
        // satisfy noUncheckedIndexedAccess
        return;
      }

      dispatch(isModalOpenChanged(false));

      const { image_name } = imageDTO;

      const state = getState();
      const lastSelectedImage =
        state.gallery.selection[state.gallery.selection.length - 1];
        state.gallery.selection[state.gallery.selection.length - 1]?.image_name;

      if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
      const { image_name } = imageDTO;

      if (lastSelectedImage === image_name) {
        const baseQueryArgs = selectListImagesBaseQueryArgs(state);
        const { data } =
          imagesApi.endpoints.listImages.select(baseQueryArgs)(state);

        const ids = data?.ids ?? [];
        const cachedImageDTOs = data
          ? imagesAdapter.getSelectors().selectAll(data)
          : [];

        const deletedImageIndex = ids.findIndex(
          (result) => result.toString() === image_name
        const deletedImageIndex = cachedImageDTOs.findIndex(
          (i) => i.image_name === image_name
        );

        const filteredIds = ids.filter((id) => id.toString() !== image_name);
        const filteredImageDTOs = cachedImageDTOs.filter(
          (i) => i.image_name !== image_name
        );

        const newSelectedImageIndex = clamp(
          deletedImageIndex,
          0,
          filteredIds.length - 1
          filteredImageDTOs.length - 1
        );

        const newSelectedImageId = filteredIds[newSelectedImageIndex];
        const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];

        if (newSelectedImageId) {
          dispatch(imageSelected(newSelectedImageId as string));
        if (newSelectedImageDTO) {
          dispatch(imageSelected(newSelectedImageDTO));
        } else {
          dispatch(imageSelected(null));
        }
@@ -97,6 +112,66 @@ export const addRequestedImageDeletionListener = () => {
  });
};

/**
 * Called when the user requests an image deletion
 */
export const addRequestedMultipleImageDeletionListener = () => {
  startAppListening({
    actionCreator: imageDeletionConfirmed,
    effect: async (action, { dispatch, getState }) => {
      const { imageDTOs, imagesUsage } = action.payload;

      if (imageDTOs.length < 1 || imagesUsage.length < 1) {
        // handle singles in separate listener
        return;
      }

      try {
        // Delete from server
        await dispatch(
          imagesApi.endpoints.deleteImages.initiate({ imageDTOs })
        ).unwrap();
        const state = getState();
        const baseQueryArgs = selectListImagesBaseQueryArgs(state);
        const { data } =
          imagesApi.endpoints.listImages.select(baseQueryArgs)(state);

        const newSelectedImageDTO = data
          ? imagesAdapter.getSelectors().selectAll(data)[0]
          : undefined;

        if (newSelectedImageDTO) {
          dispatch(imageSelected(newSelectedImageDTO));
        } else {
          dispatch(imageSelected(null));
        }

        dispatch(isModalOpenChanged(false));

        // We need to reset the features where the image is in use - none of these work if their image(s) don't exist

        if (imagesUsage.some((i) => i.isCanvasImage)) {
          dispatch(resetCanvas());
        }

        if (imagesUsage.some((i) => i.isControlNetImage)) {
          dispatch(controlNetReset());
        }

        if (imagesUsage.some((i) => i.isInitialImage)) {
          dispatch(clearInitialImage());
        }

        if (imagesUsage.some((i) => i.isNodesImage)) {
          dispatch(nodeEditorReset());
        }
      } catch {
        // no-op
      }
    },
  });
};

/**
 * Called when the actual delete request is sent to the server
 */
@@ -6,10 +6,7 @@
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import {
  imageSelected,
  imagesAddedToBatch,
} from 'features/gallery/store/gallerySlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -27,19 +24,32 @@ export const addImageDroppedListener = () => {
      const log = logger('images');
      const { activeData, overData } = action.payload;

      log.debug({ activeData, overData }, 'Image or selection dropped');
      if (activeData.payloadType === 'IMAGE_DTO') {
        log.debug({ activeData, overData }, 'Image dropped');
      } else if (activeData.payloadType === 'IMAGE_DTOS') {
        log.debug(
          { activeData, overData },
          `Images (${activeData.payload.imageDTOs.length}) dropped`
        );
      } else {
        log.debug({ activeData, overData }, `Unknown payload dropped`);
      }

      // set current image
      /**
       * Image dropped on current image
       */
      if (
        overData.actionType === 'SET_CURRENT_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        dispatch(imageSelected(activeData.payload.imageDTO.image_name));
        dispatch(imageSelected(activeData.payload.imageDTO));
        return;
      }

      // set initial image
      /**
       * Image dropped on initial image
       */
      if (
        overData.actionType === 'SET_INITIAL_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
@@ -49,27 +59,9 @@ export const addImageDroppedListener = () => {
        return;
      }

      // add image to batch
      if (
        overData.actionType === 'ADD_TO_BATCH' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        dispatch(imagesAddedToBatch([activeData.payload.imageDTO.image_name]));
        return;
      }

      // add multiple images to batch
      if (
        overData.actionType === 'ADD_TO_BATCH' &&
        activeData.payloadType === 'IMAGE_NAMES'
      ) {
        dispatch(imagesAddedToBatch(activeData.payload.image_names));

        return;
      }

      // set control image
      /**
       * Image dropped on ControlNet
       */
      if (
        overData.actionType === 'SET_CONTROLNET_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
@@ -85,7 +77,9 @@ export const addImageDroppedListener = () => {
        return;
      }

      // set canvas image
      /**
       * Image dropped on Canvas
       */
      if (
        overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
@@ -95,7 +89,9 @@ export const addImageDroppedListener = () => {
        return;
      }

      // set nodes image
      /**
       * Image dropped on node image field
       */
      if (
        overData.actionType === 'SET_NODES_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
@@ -112,61 +108,36 @@ export const addImageDroppedListener = () => {
        return;
      }

      // set multiple nodes images (single image handler)
      if (
        overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        const { fieldName, nodeId } = overData.context;
        dispatch(
          fieldValueChanged({
            nodeId,
            fieldName,
            value: [activeData.payload.imageDTO],
          })
        );
        return;
      }

      // // set multiple nodes images (multiple images handler)
      /**
       * TODO
       * Image selection dropped on node image collection field
       */
      // if (
      //   overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
      //   activeData.payloadType === 'IMAGE_NAMES'
      //   activeData.payloadType === 'IMAGE_DTO' &&
      //   activeData.payload.imageDTO
      // ) {
      //   const { fieldName, nodeId } = overData.context;
      //   dispatch(
      //     imageCollectionFieldValueChanged({
      //     fieldValueChanged({
      //       nodeId,
      //       fieldName,
      //       value: activeData.payload.image_names.map((image_name) => ({
      //         image_name,
      //       })),
      //       value: [activeData.payload.imageDTO],
      //     })
      //   );
      //   return;
      // }

      // add image to board
      /**
       * Image dropped on user board
       */
      if (
        overData.actionType === 'MOVE_BOARD' &&
        overData.actionType === 'ADD_TO_BOARD' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        const { imageDTO } = activeData.payload;
        const { boardId } = overData.context;

        // image was dropped on the "NoBoardBoard"
        if (!boardId) {
          dispatch(
            imagesApi.endpoints.removeImageFromBoard.initiate({
              imageDTO,
            })
          );
          return;
        }

        // image was dropped on a user board
        dispatch(
          imagesApi.endpoints.addImageToBoard.initiate({
            imageDTO,
@@ -176,67 +147,58 @@ export const addImageDroppedListener = () => {
        return;
      }

      // // add gallery selection to board
      // if (
      //   overData.actionType === 'MOVE_BOARD' &&
      //   activeData.payloadType === 'IMAGE_NAMES' &&
      //   overData.context.boardId
      // ) {
      //   console.log('adding gallery selection to board');
      //   const board_id = overData.context.boardId;
      //   dispatch(
      //     boardImagesApi.endpoints.addManyBoardImages.initiate({
      //       board_id,
      //       image_names: activeData.payload.image_names,
      //     })
      //   );
      //   return;
      // }
      /**
       * Image dropped on 'none' board
       */
      if (
        overData.actionType === 'REMOVE_FROM_BOARD' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        const { imageDTO } = activeData.payload;
        dispatch(
          imagesApi.endpoints.removeImageFromBoard.initiate({
            imageDTO,
          })
        );
        return;
      }

      // // remove gallery selection from board
      // if (
      //   overData.actionType === 'MOVE_BOARD' &&
      //   activeData.payloadType === 'IMAGE_NAMES' &&
      //   overData.context.boardId === null
      // ) {
      //   console.log('removing gallery selection to board');
      //   dispatch(
      //     boardImagesApi.endpoints.deleteManyBoardImages.initiate({
      //       image_names: activeData.payload.image_names,
      //     })
      //   );
      //   return;
      // }
      /**
       * Multiple images dropped on user board
       */
      if (
        overData.actionType === 'ADD_TO_BOARD' &&
        activeData.payloadType === 'IMAGE_DTOS' &&
        activeData.payload.imageDTOs
      ) {
        const { imageDTOs } = activeData.payload;
        const { boardId } = overData.context;
        dispatch(
          imagesApi.endpoints.addImagesToBoard.initiate({
            imageDTOs,
            board_id: boardId,
          })
        );
        return;
      }

      // // add batch selection to board
      // if (
      //   overData.actionType === 'MOVE_BOARD' &&
      //   activeData.payloadType === 'IMAGE_NAMES' &&
      //   overData.context.boardId
      // ) {
      //   const board_id = overData.context.boardId;
      //   dispatch(
      //     boardImagesApi.endpoints.addManyBoardImages.initiate({
      //       board_id,
      //       image_names: activeData.payload.image_names,
      //     })
      //   );
      //   return;
      // }

      // // remove batch selection from board
      // if (
      //   overData.actionType === 'MOVE_BOARD' &&
      //   activeData.payloadType === 'IMAGE_NAMES' &&
      //   overData.context.boardId === null
      // ) {
      //   dispatch(
      //     boardImagesApi.endpoints.deleteManyBoardImages.initiate({
      //       image_names: activeData.payload.image_names,
      //     })
      //   );
      //   return;
      // }
      /**
       * Multiple images dropped on 'none' board
       */
      if (
        overData.actionType === 'REMOVE_FROM_BOARD' &&
        activeData.payloadType === 'IMAGE_DTOS' &&
        activeData.payload.imageDTOs
      ) {
        const { imageDTOs } = activeData.payload;
        dispatch(
          imagesApi.endpoints.removeImagesFromBoard.initiate({
            imageDTOs,
          })
        );
        return;
      }
    },
  });
};
@@ -1,37 +1,32 @@
import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
import { selectImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { selectImageUsage } from 'features/deleteImageModal/store/selectors';
import {
  imageToDeleteSelected,
  imagesToDeleteSelected,
  isModalOpenChanged,
} from 'features/imageDeletion/store/imageDeletionSlice';
} from 'features/deleteImageModal/store/slice';
import { startAppListening } from '..';

export const addImageToDeleteSelectedListener = () => {
  startAppListening({
    actionCreator: imageToDeleteSelected,
    actionCreator: imagesToDeleteSelected,
    effect: async (action, { dispatch, getState }) => {
      const imageDTO = action.payload;
      const imageDTOs = action.payload;
      const state = getState();
      const { shouldConfirmOnDelete } = state.system;
      const imageUsage = selectImageUsage(getState());

      if (!imageUsage) {
        // should never happen
        return;
      }
      const imagesUsage = selectImageUsage(getState());

      const isImageInUse =
        imageUsage.isCanvasImage ||
        imageUsage.isInitialImage ||
        imageUsage.isControlNetImage ||
        imageUsage.isNodesImage;
        imagesUsage.some((i) => i.isCanvasImage) ||
        imagesUsage.some((i) => i.isInitialImage) ||
        imagesUsage.some((i) => i.isControlNetImage) ||
        imagesUsage.some((i) => i.isNodesImage);

      if (shouldConfirmOnDelete || isImageInUse) {
        dispatch(isModalOpenChanged(true));
        return;
      }

      dispatch(imageDeletionConfirmed({ imageDTO, imageUsage }));
      dispatch(imageDeletionConfirmed({ imageDTOs, imagesUsage }));
    },
  });
};
@@ -2,14 +2,13 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { imagesAddedToBatch } from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { startAppListening } from '..';
import { imagesApi } from '../../../../../services/api/endpoints/images';
import { omit } from 'lodash-es';

const DEFAULT_UPLOADED_TOAST: UseToastOptions = {
  title: 'Image Uploaded',
@@ -41,7 +40,7 @@ export const addImageUploadedFulfilledListener = () => {
      // default action - just upload and alert user
      if (postUploadAction?.type === 'TOAST') {
        const { toastOptions } = postUploadAction;
        if (!autoAddBoardId) {
        if (!autoAddBoardId || autoAddBoardId === 'none') {
          dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions }));
        } else {
          // Add this image to the board
@@ -121,17 +120,6 @@ export const addImageUploadedFulfilledListener = () => {
        );
        return;
      }

      if (postUploadAction?.type === 'ADD_TO_BATCH') {
        dispatch(imagesAddedToBatch([imageDTO.image_name]));
        dispatch(
          addToast({
            ...DEFAULT_UPLOADED_TOAST,
            description: 'Added to batch',
          })
        );
        return;
      }
    },
  });
};
@@ -15,7 +15,7 @@ import {
  setShouldUseSDXLRefiner,
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models';
import { startAppListening } from '..';

export const addModelsLoadedListener = () => {
@@ -144,8 +144,9 @@ export const addModelsLoadedListener = () => {
        return;
      }

      const firstModelId = action.payload.ids[0];
      const firstModel = action.payload.entities[firstModelId];
      const firstModel = vaeModelsAdapter
        .getSelectors()
        .selectAll(action.payload)[0];

      if (!firstModel) {
        // No custom VAEs loaded at all; use the default
@@ -8,9 +8,10 @@ import {
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { progressImageSet } from 'features/system/store/systemSlice';
import { imagesAdapter, imagesApi } from 'services/api/endpoints/images';
import { imagesApi } from 'services/api/endpoints/images';
import { isImageOutput } from 'services/api/guards';
import { sessionCanceled } from 'services/api/thunks/session';
import { imagesAdapter } from 'services/api/util';
import {
  appSocketInvocationComplete,
  socketInvocationComplete,
@@ -67,7 +68,7 @@ export const addInvocationCompleteEventListener = () => {
         */

        const { autoAddBoardId } = gallery;
        if (autoAddBoardId) {
        if (autoAddBoardId && autoAddBoardId !== 'none') {
          dispatch(
            imagesApi.endpoints.addImageToBoard.initiate({
              board_id: autoAddBoardId,
@@ -83,10 +84,7 @@ export const addInvocationCompleteEventListener = () => {
              categories: IMAGE_CATEGORIES,
            },
            (draft) => {
              const oldTotal = draft.total;
              const newState = imagesAdapter.addOne(draft, imageDTO);
              const delta = newState.total - oldTotal;
              draft.total = draft.total + delta;
              imagesAdapter.addOne(draft, imageDTO);
            }
          )
        );
@@ -94,8 +92,8 @@ export const addInvocationCompleteEventListener = () => {

        dispatch(
          imagesApi.util.invalidateTags([
            { type: 'BoardImagesTotal', id: autoAddBoardId ?? 'none' },
            { type: 'BoardAssetsTotal', id: autoAddBoardId ?? 'none' },
            { type: 'BoardImagesTotal', id: autoAddBoardId },
            { type: 'BoardAssetsTotal', id: autoAddBoardId },
          ])
        );

@@ -110,7 +108,7 @@ export const addInvocationCompleteEventListener = () => {
        } else if (!autoAddBoardId) {
          dispatch(galleryViewChanged('images'));
        }
        dispatch(imageSelected(imageDTO.image_name));
        dispatch(imageSelected(imageDTO));
      }
    }
Some files were not shown because too many files have changed in this diff.